[SYSTEMDS-396] Distinct values count/estimation functions
New function for counting the number of distinct values in a
MatrixBlock. It uses the builtin aggregate instructions to pass
through the HOP and LOP compilation. It can be executed with different
types of estimators:
- count : The default implementation, which counts by adding the values
to a hash set.
Not memory efficient, but returns exact counts.
- KMV : An estimation algorithm based on K Minimum Values
- HLL : An estimation algorithm based on HyperLogLog (not finished)
Closes #909.
diff --git a/.github/workflows/functionsTests.yml b/.github/workflows/functionsTests.yml
index b983018..fb1a5bb 100644
--- a/.github/workflows/functionsTests.yml
+++ b/.github/workflows/functionsTests.yml
@@ -49,6 +49,7 @@
codegen,
codegenalg.partone,
codegenalg.parttwo,
+ countDistinct,
data.misc,
data.rand,
data.tensor,
diff --git a/dev/docs/Tasks.txt b/dev/docs/Tasks.txt
index 9a51eb5..f3d4acd 100644
--- a/dev/docs/Tasks.txt
+++ b/dev/docs/Tasks.txt
@@ -310,6 +310,7 @@
* 393 Builtin to find Connected Components of a graph OK
* 394 Builtin for one-hot encoding of matrix (not frame), see table OK
* 395 SVM rework and utils (confusionMatrix, msvmPredict) OK
+ * 396 Builtin for counting number of distinct values OK
SYSTEMDS-400 Spark Backend Improvements
* 401 Fix output block indexes of rdiag (diagM2V) OK
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 7345077..5ee7a79 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -178,6 +178,8 @@
TRACE("trace", false),
TO_ONE_HOT("toOneHot", true),
TYPEOF("typeOf", false),
+ COUNT_DISTINCT("countDistinct",false),
+ COUNT_DISTINCT_APPROX("countDistinctApprox",false),
VAR("var", false),
XOR("xor", false),
WINSORIZE("winsorize", true, false), //TODO parameterize w/ prob, min/max val
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index 2d66e81..996132f 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -175,7 +175,9 @@
PROD, SUM_PROD,
MIN, MAX,
TRACE, MEAN, VAR,
- MAXINDEX, MININDEX;
+ MAXINDEX, MININDEX,
+ COUNT_DISTINCT,
+ COUNT_DISTINCT_APPROX;
@Override
public String toString() {
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index 576a5e3..bfec9ff 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -19,15 +19,14 @@
package org.apache.sysds.lops;
-import org.apache.sysds.hops.HopsException;
-import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
-
-import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.CorrectionLocationType;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
+import org.apache.sysds.hops.HopsException;
+import org.apache.sysds.lops.LopProperties.ExecType;
/**
@@ -77,7 +76,7 @@
input.addOutput(this);
lps.setProperties(inputs, et);
}
-
+
/**
* This method computes the location of "correction" terms in the output
* produced by PartialAgg instruction.
@@ -340,6 +339,18 @@
return "uaktrace";
break;
}
+
+ case COUNT_DISTINCT: {
+ if(dir == Direction.RowCol )
+ return "uacd";
+ break;
+ }
+
+ case COUNT_DISTINCT_APPROX: {
+ if(dir == Direction.RowCol )
+ return "uacdap";
+ break;
+ }
}
//should never come here for normal compilation
diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 2c5d61a..96e2ebc 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -912,9 +912,10 @@
case NROW:
case NCOL:
case LENGTH:
+ case COUNT_DISTINCT:
+ case COUNT_DISTINCT_APPROX:
checkNumParameters(1);
- checkDataTypeParam(getFirstExpr(),
- DataType.MATRIX, DataType.FRAME, DataType.LIST);
+ checkDataTypeParam(getFirstExpr(), DataType.MATRIX);
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlocksize(0);
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index de1e3ce..b8c7bcf 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2313,6 +2313,8 @@
case SUM:
case PROD:
case VAR:
+ case COUNT_DISTINCT:
+ case COUNT_DISTINCT_APPROX:
currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(),
AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr);
break;
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
index 8cac05c..7e5b73a 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
@@ -50,7 +50,8 @@
public enum BuiltinCode { SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN,
MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX,
STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST,
- TYPEOF, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID }
+ TYPEOF, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID, COUNT_DISTINCT, COUNT_DISTINCT_APPROX}
+
public BuiltinCode bFunc;
private static final boolean FASTMATH = true;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index c613222..d4a78ef 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -107,6 +107,8 @@
String2CPInstructionType.put( "length" ,CPType.AggregateUnary);
String2CPInstructionType.put( "exists" ,CPType.AggregateUnary);
String2CPInstructionType.put( "lineage" ,CPType.AggregateUnary);
+ String2CPInstructionType.put( "uacd" , CPType.AggregateUnary);
+ String2CPInstructionType.put( "uacdap" , CPType.AggregateUnary);
String2CPInstructionType.put( "uaggouterchain", CPType.UaggOuterChain);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
index 5f053e9..100251d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
@@ -31,9 +31,11 @@
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
+import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
@@ -42,7 +44,8 @@
public class AggregateUnaryCPInstruction extends UnaryCPInstruction
{
public enum AUType {
- NROW, NCOL, LENGTH, EXISTS, LINEAGE,
+ NROW, NCOL, LENGTH, EXISTS, LINEAGE,
+ COUNT_DISTINCT, COUNT_DISTINCT_APPROX,
DEFAULT;
public boolean isMeta() {
return this != DEFAULT;
@@ -72,6 +75,14 @@
|| opcode.equalsIgnoreCase("lineage")){
return new AggregateUnaryCPInstruction(new SimpleOperator(Builtin.getBuiltinFnObject(opcode)),
in1, out, AUType.valueOf(opcode.toUpperCase()), opcode, str);
+ }
+ else if(opcode.equalsIgnoreCase("uacd")){
+ return new AggregateUnaryCPInstruction(new SimpleOperator(null),
+ in1, out, AUType.COUNT_DISTINCT, opcode, str);
+ }
+ else if(opcode.equalsIgnoreCase("uacdap")){
+ return new AggregateUnaryCPInstruction(new SimpleOperator(null),
+ in1, out, AUType.COUNT_DISTINCT_APPROX, opcode, str);
}
else { //DEFAULT BEHAVIOR
AggregateUnaryOperator aggun = InstructionUtils
@@ -152,6 +163,17 @@
ec.setScalarOutput(output_name, new StringObject(Explain.explain(li)));
break;
}
+ case COUNT_DISTINCT:
+ case COUNT_DISTINCT_APPROX: {
+ if( !ec.getVariables().keySet().contains(input1.getName()) )
+ throw new DMLRuntimeException("Variable '" + input1.getName() + "' does not exist.");
+ MatrixBlock input = ec.getMatrixInput(input1.getName());
+ CountDistinctOperator op = new CountDistinctOperator(_type);
+ int res = LibMatrixCountDistinct.estimateDistinctValues(input, op);
+ ec.releaseMatrixInput(input1.getName());
+ ec.setScalarOutput(output_name, new IntObject(res));
+ break;
+ }
default: {
AggregateUnaryOperator au_op = (AggregateUnaryOperator) _optr;
if (input1.getDataType() == DataType.MATRIX) {
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
new file mode 100644
index 0000000..73e89e8
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.PriorityQueue;
+import java.util.Set;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator.CountDistinctTypes;
+import org.apache.sysds.utils.Hash;
+import org.apache.sysds.utils.Hash.HashType;
+
+/**
+ * This class contains various methods for counting the number of distinct values inside a MatrixBlock
+ */
+public class LibMatrixCountDistinct {
+
+ // ------------------------------
+ // Logging parameters:
+ // local debug flag
+ private static final boolean LOCAL_DEBUG = false;
+ // DEBUG/TRACE for details
+ private static final Level LOCAL_DEBUG_LEVEL = Level.DEBUG;
+
+ private static final Log LOG = LogFactory.getLog(LibMatrixCountDistinct.class.getName());
+
+ static {
+ // for internal debugging only
+ if(LOCAL_DEBUG) {
+ Logger.getLogger("org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct").setLevel(LOCAL_DEBUG_LEVEL);
+ }
+ }
+ // ------------------------------
+
+ /**
+	 * The minimum number of non-zero cells in the input before approximate techniques are used for counting the
+	 * number of distinct values.
+ */
+ public static int minimumSize = 1024;
+
+ private LibMatrixCountDistinct() {
+ // Prevent instantiation via private constructor.
+ }
+
+	/**
+	 * Public method to count the number of distinct values inside a matrix. Depending on the type of the given
+	 * CountDistinctOperator, it returns either the exact number or an estimate.
+	 *
+	 * TODO: Support counting the number of distinct values along the row or column axis.
+	 *
+	 * TODO: Add support for distributed spark operations
+	 *
+	 * TODO: If the MatrixBlock type is CompressedMatrix, simply read the values from the ColGroups.
+	 *
+	 * @param in the input matrix to count the number of distinct values in
+	 * @param op the selected operator to use
+	 * @return the distinct count
+	 */
+	public static int estimateDistinctValues(MatrixBlock in, CountDistinctOperator op) {
+		final CountDistinctTypes type = op.operatorType;
+		final HashType hashing = op.hashType;
+
+		// Unsupported configurations are rejected before the input is inspected.
+		if(type == CountDistinctTypes.KMV && (hashing == HashType.ExpHash || hashing == HashType.StandardJava))
+			throw new DMLException("Invalid hashing configuration using " + op.hashType + " and " + op.operatorType);
+		if(type == CountDistinctTypes.HLL)
+			throw new NotImplementedException("HyperLogLog not implemented");
+
+		// shortcut in simplest case.
+		if( in.getLength() == 1 || in.isEmpty() )
+			return 1;
+
+		int count;
+		if( in.getNonZeros() < minimumSize ) {
+			// Inputs with few non-zero values are counted exactly, whichever estimator was selected.
+			count = countDistinctValuesNaive(in);
+		}
+		else {
+			switch(type) {
+				case COUNT: count = countDistinctValuesNaive(in); break;
+				case KMV: count = countDistinctValuesKVM(in, op); break;
+				default:
+					throw new DMLException("Invalid or not implemented Estimator Type");
+			}
+		}
+
+		// A count of zero is treated as an invalid result.
+		if(count == 0)
+			throw new DMLRuntimeException("Imposible estimate of distinct values");
+		return count;
+	}
+
+	/**
+	 * Exact counting of distinct values: every cell value is inserted into a hash set.
+	 *
+	 * The result is precise, but the memory consumption grows with the number of distinct values of the input.
+	 *
+	 * @param in The input matrix to count the number of distinct values in
+	 * @return The exact distinct count
+	 */
+	private static int countDistinctValuesNaive(MatrixBlock in) {
+		final Set<Double> seen = new HashSet<>();
+
+		// TODO performance: direct sparse block /dense block access
+		if(!in.isInSparseFormat()) {
+			// TODO fix for large dense blocks, where this call will fail
+			final double[] values = in.getDenseBlockValues();
+			if(values == null)
+				throw new DMLRuntimeException("Not valid execution");
+			// TODO avoid redundantly adding zero if not entirely dense
+			for(int i = 0; i < values.length; i++)
+				seen.add(values[i]);
+			return seen.size();
+		}
+
+		// Sparse format: collect the values delivered by the sparse block iterator.
+		for(Iterator<IJV> cells = in.getSparseBlockIterator(); cells.hasNext();)
+			seen.add(cells.next().getV());
+
+		// Fewer non-zeros than cells means that the value zero occurs as well.
+		if( in.getNonZeros() < in.getLength() )
+			seen.add(0d);
+
+		return seen.size();
+	}
+
+	/**
+	 * KMV synopsis (for k minimum values) distinct-value estimation.
+	 *
+	 * Kevin S. Beyer, Peter J. Haas, Berthold Reinwald, Yannis Sismanis, Rainer Gemulla:
+	 *
+	 * On synopses for distinct-value estimation under multiset operations. SIGMOD 2007
+	 *
+	 * TODO: Add multi-threaded version
+	 *
+	 * @param in The Matrix Block to estimate the number of distinct values in
+	 * @param op The operator that selects the hash function applied to the cell values
+	 * @return The distinct count estimate
+	 */
+	private static int countDistinctValuesKVM(MatrixBlock in, CountDistinctOperator op) {
+		// D is the number of possible distinct values in the MatrixBlock,
+		// plus 1 to take account of 0 input.
+		final long D = in.getNonZeros() + 1;
+
+		// To keep hash collisions unlikely, O(D^2) positions to hash to are required. Hashes are ints, so the
+		// space is capped at Integer.MAX_VALUE. The cap is decided on D itself (46340 is the largest D whose
+		// square still fits an int), so that D * D can never overflow the long range and wrap around.
+		final int M = (D > 46340L) ? Integer.MAX_VALUE : (int) (D * D);
+		LOG.debug("M: " + M);
+
+		// The estimator is asymptotically unbiased as k becomes large, but memory usage also scales with k.
+		// Furthermore the k value must be within range: D >> k >> 0
+		final int k = D > 64 ? 64 : (int) D;
+		final SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
+
+		if(in.isInSparseFormat()) {
+			Iterator<IJV> it = in.getSparseBlockIterator();
+			while(it.hasNext())
+				spq.add(hashToRange(it.next().getV(), op.hashType, M));
+			// Zero is mapped into the same range as every other value, so that the sparse and the dense
+			// path insert the same synopsis entry for it.
+			if( in.getNonZeros() < in.getLength() )
+				spq.add(hashToRange(0d, op.hashType, M));
+		}
+		else {
+			// TODO fix for large dense blocks, where this call will fail
+			double[] data = in.getDenseBlockValues();
+			if(data == null)
+				throw new DMLRuntimeException("Not valid execution");
+			for(double fullValue : data)
+				spq.add(hashToRange(fullValue, op.hashType, M));
+		}
+
+		LOG.debug("smallest hash:" + spq.peek());
+		LOG.debug("spq: " + spq.toString());
+
+		if(spq.size() < k)
+			return spq.size();
+
+		// U_k is the k-th smallest hash, normalized by the size of the hashing space.
+		double U_k = (double) spq.poll() / (double) M;
+		LOG.debug("U_k : " + U_k);
+		double estimate = (double) (k - 1) / U_k;
+		LOG.debug("Estimate: " + estimate);
+		// The estimate is capped at D, the number of possible distinct values.
+		double ceilEstimate = Math.min(estimate, (double) D);
+		LOG.debug("Ceil worst case: " + ceilEstimate);
+		return (int) ceilEstimate;
+	}
+
+	/** Hashes a cell value and maps the hash into the range [1, M - 1] of the KMV synopsis. */
+	private static int hashToRange(double value, HashType hashType, int M) {
+		int hash = Hash.hash(value, hashType);
+		// Java has no unsigned int and Math.abs(Integer.MIN_VALUE) stays negative, so that single case is
+		// mapped explicitly onto the largest positive int.
+		int nonNegative = (hash == Integer.MIN_VALUE) ? Integer.MAX_VALUE : Math.abs(hash);
+		return nonNegative % (M - 1) + 1;
+	}
+
+	/**
+	 * Despite its name, this container keeps the k smallest distinct values that were inserted.
+	 *
+	 * TODO: add utility method to join two partitions
+	 *
+	 * TODO: Replace Standard Java Set and Priority Queue with optimized versions.
+	 */
+	private static class SmallestPriorityQueue {
+		private final Set<Integer> members = new HashSet<>(1); // values currently held, rejects duplicates
+		private final PriorityQueue<Integer> heap; // reverse order: the head is the largest value held
+		private final int capacity; // maximum number of values held
+
+		public SmallestPriorityQueue(int k) {
+			heap = new PriorityQueue<>(k, Collections.reverseOrder());
+			capacity = k;
+		}
+
+		public void add(int v) {
+			if(members.contains(v))
+				return;
+			if(heap.size() < capacity) {
+				heap.add(v);
+				members.add(v);
+				return;
+			}
+			if(v < heap.peek()) {
+				LOG.trace(heap.peek() + " -- " + v);
+				heap.add(v);
+				members.add(v);
+				members.remove(heap.poll());
+			}
+		}
+
+		public int size() {
+			return heap.size();
+		}
+
+		public int peek() {
+			return heap.peek();
+		}
+
+		public int poll() {
+			return heap.poll();
+		}
+
+		@Override
+		public String toString() {
+			return heap.toString();
+		}
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
new file mode 100644
index 0000000..3f63ef9
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.operators;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction.AUType;
+import org.apache.sysds.utils.Hash.HashType;
+
+public class CountDistinctOperator extends Operator {
+	private static final long serialVersionUID = 7615123453265129670L;
+
+	/** The different supported types of counting. */
+	public enum CountDistinctTypes {
+		COUNT, // Baseline naive implementation, iterate though, add to hashMap.
+		KMV, // K-Minimum Values algorithm.
+		HLL // HyperLogLog algorithm.
+	}
+
+	public final CountDistinctTypes operatorType;
+	public final HashType hashType;
+
+	/** Builds the operator for an aggregate-unary instruction type; the hash function is LinearHash. */
+	public CountDistinctOperator(AUType opType) {
+		this(fromInstructionType(opType), HashType.LinearHash);
+	}
+
+	/** Builds the operator for the given counting type; the hash function is StandardJava. */
+	public CountDistinctOperator(CountDistinctTypes operatorType) {
+		this(operatorType, HashType.StandardJava);
+	}
+
+	/** Builds the operator for the given counting type and hash function. */
+	public CountDistinctOperator(CountDistinctTypes operatorType, HashType hashType) {
+		super(true);
+		this.operatorType = operatorType;
+		this.hashType = hashType;
+	}
+
+	/** Maps COUNT_DISTINCT to COUNT and COUNT_DISTINCT_APPROX to KMV; any other type is rejected. */
+	private static CountDistinctTypes fromInstructionType(AUType opType) {
+		switch(opType) {
+			case COUNT_DISTINCT: return CountDistinctTypes.COUNT;
+			case COUNT_DISTINCT_APPROX: return CountDistinctTypes.KMV;
+			default: throw new DMLRuntimeException(opType + " not supported for CountDistinct Operator");
+		}
+	}
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java
index dfb369d..3ac19f2 100644
--- a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java
@@ -457,6 +457,26 @@
}
/**
+	 * Converts an integer matrix to a MatrixBlock
+ *
+ * @param data Int matrix input that is converted to double MatrixBlock
+ * @return The matrixBlock constructed.
+ */
+	public static MatrixBlock convertToMatrixBlock(int[][] data){
+		final int numRows = data.length;
+		final int numCols = (numRows > 0) ? data[0].length : 0;
+		final MatrixBlock out = new MatrixBlock(numRows, numCols, false);
+		for(int r = 0; r < numRows; r++) {
+			// only the non-zero cells are appended to the block
+			for(int c = 0; c < numCols; c++) {
+				final double cell = data[r][c];
+				if( cell != 0 )
+					out.appendValue(r, c, cell);
+			}
+		}
+		return out;
+	}
+
+ /**
* Creates a dense Matrix Block and copies the given double vector into it.
*
* @param data double array
diff --git a/src/main/java/org/apache/sysds/utils/Hash.java b/src/main/java/org/apache/sysds/utils/Hash.java
new file mode 100644
index 0000000..3bc3ca7
--- /dev/null
+++ b/src/main/java/org/apache/sysds/utils/Hash.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.utils;
+
+import org.apache.commons.lang3.NotImplementedException;
+
+/**
+ * A class containing different hashing functions.
+ */
+public class Hash {
+
+	/**
+	 * Available Hashing techniques
+	 */
+	public enum HashType {
+		StandardJava, LinearHash, ExpHash
+	}
+
+	/**
+	 * A random Array (except first value) used for Linear and Exp hashing, to integer domain.
+	 */
+	private static final int[] a = {0xFFFFFFFF, 0xB7825CBC, 0x10FA23F2, 0xD54E1532, 0x7590E53C, 0xECE6F631, 0x8954BF60,
+		0x5BE38B88, 0xCA1D3AC0, 0xB2726F8E, 0xBADE7E7A, 0xCACD1184, 0xFB32BDAD, 0x2936C9D7, 0xB5B88D37, 0xD272D353,
+		0xE139A063, 0xDACF6B87, 0x3568D521, 0x75C619EA, 0x7C2B8CBD, 0x012C3C7F, 0x0A621C37, 0x77274A12, 0x731D379A,
+		0xE45E0D3B, 0xEAB4AE13, 0x10C440C7, 0x50CF2899, 0xD865BD46, 0xAABDF34F, 0x218FA0C3,};
+
+	/**
+	 * Generic hashing of java objects, not ideal for specific values so use the specific methods for specific types.
+	 *
+	 * To use the locality sensitive techniques, override the object's hashCode function.
+	 *
+	 * @param o  The Object to hash.
+	 * @param ht The HashType to use.
+	 * @return An int Hash value.
+	 */
+	public static int hash(Object o, HashType ht) {
+		int hashcode = o.hashCode();
+		switch(ht) {
+			case StandardJava:
+				return hashcode;
+			case LinearHash:
+				return linearHash(hashcode);
+			case ExpHash:
+				return expHash(hashcode);
+			default:
+				throw new NotImplementedException("Not Implemented hashing combination");
+		}
+	}
+
+	/**
+	 * Hash functions for double values. Only StandardJava and LinearHash are supported for doubles.
+	 *
+	 * @param o  The double value.
+	 * @param ht The hashing function to apply.
+	 * @return An int Hash value.
+	 */
+	public static int hash(double o, HashType ht) {
+		switch(ht) {
+			case StandardJava:
+				// Here just for reference; the static method yields the boxed hash code without allocating a Double.
+				return Double.hashCode(o);
+			case LinearHash:
+				// Although linear hashing is locality sensitive, it is not in this case,
+				// since the bit positions of the double value are split into exponent and mantissa.
+				// If the locality sensitive aspect is required, use linear hash on a double value rounded to integer.
+				long v = Double.doubleToLongBits(o);
+				return linearHash((int) (v ^ (v >>> 32)));
+			default:
+				throw new NotImplementedException("Not Implemented hashing combination for double value");
+		}
+	}
+
+	/**
+	 * Compute the Linear hash of an int input value.
+	 *
+	 * @param v The value to hash.
+	 * @return The int hash.
+	 */
+	public static int linearHash(int v) {
+		return linearHash(v, a.length);
+	}
+
+	/**
+	 * Compute the Linear hash of an int input value, but only use the first bits of the linear hash.
+	 *
+	 * @param v    The value to hash.
+	 * @param bits The number of bits to use, at most a.length (32); one entry of the array is read per bit.
+	 * @return The hashed value
+	 */
+	public static int linearHash(int v, int bits) {
+		int res = 0;
+		for(int i = 0; i < bits; i++) {
+			res = (res << 1) + (Integer.bitCount(a[i] & v) & 1);
+		}
+		return res;
+	}
+
+	/**
+	 * Compute exponentially distributed hash values in the range 1..a.length
+	 *
+	 * eg: 50% == 1 , 25% == 2 , 12.5% == 3 etc.
+	 *
+	 * Useful because the size of a collection can be estimated by only maintaining the highest value found from this hash.
+	 *
+	 * @param x value to hash
+	 * @return a hash value byte: index + 1 of the first array entry with odd parity, or a.length if there is none
+	 */
+	public static byte expHash(int x) {
+		for(int value = 0; value < a.length; value++) {
+			int dot = Integer.bitCount(a[value] & x) & 1;
+			if(dot != 0)
+				return (byte) (value + 1);
+		}
+		return (byte) a.length;
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java
index 0596668..8696442 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -779,7 +779,7 @@
*
* @param d1 The expected value.
* @param d2 The actual value.
- * @return Whether they are equal or not.
+	 * @return The distance in bits
*/
public static long compareScalarBits(double d1, double d2) {
long expectedBits = Double.doubleToLongBits(d1) < 0 ? 0x8000000000000000L - Double.doubleToLongBits(d1) : Double.doubleToLongBits(d1);
@@ -1386,6 +1386,42 @@
}
/**
+ *
+	 * Generates a test matrix, but only containing whole numbers, in the range specified.
+ *
+ * @param rows number of rows
+ * @param cols number of columns
+ * @param min minimum value whole number
+ * @param max maximum value whole number
+ * @param sparsity sparsity
+ * @param seed seed
+ * @return random matrix containing whole numbers in the range specified.
+ */
+	public static int[][] generateTestMatrixIntV(int rows, int cols, int min, int max, double sparsity, long seed) {
+		final int[][] result = new int[rows][cols];
+		final Random generator = (seed == -1) ? TestUtils.random : new Random(seed);
+		// Width of the value interval; zero means that min and max coincide.
+		final int span = max - min;
+
+		for (int i = 0; i < rows; i++) {
+			for (int j = 0; j < cols; j++) {
+				// Cells that fail the sparsity draw keep the array default of zero.
+				if (generator.nextDouble() > sparsity)
+					continue;
+				if (span != 0) {
+					// nextInt(span) is drawn from [0, span), so max itself is never produced.
+					result[i][j] = generator.nextInt(span) + min;
+				}
+				else {
+					result[i][j] = max;
+				}
+			}
+		}
+
+		return result;
+	}
+
+ /**
* <p>
* Generates a test matrix with the specified parameters as a two
* dimensional array. The matrix will not contain any zero values.
diff --git a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
new file mode 100644
index 0000000..a8e3e2b
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.matrix;
+
+import java.util.ArrayList;
+import java.util.Collection;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator.CountDistinctTypes;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Hash.HashType;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(value = Parameterized.class)
+public class CountDistinctTest {
+
+ private static CountDistinctTypes[] esT = new CountDistinctTypes[] {
+ // The different types of Estimators
+ CountDistinctTypes.COUNT,
+ CountDistinctTypes.KMV,
+ CountDistinctTypes.HLL
+ };
+
+ @Parameters
+ public static Collection<Object[]> data() {
+ ArrayList<Object[]> tests = new ArrayList<>();
+ ArrayList<MatrixBlock> inputs = new ArrayList<>();
+ ArrayList<Long> actualUnique = new ArrayList<>();
+
+ // single value matrix.
+ inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrix(1, 1, 0.0, 100.0, 1, 7)));
+ actualUnique.add(1L);
+
+ // single column or row matrix.
+ inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrix(1, 100, 0.0, 100.0, 1, 7)));
+ actualUnique.add(100L);
+ inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrix(100, 1, 0.0, 100.0, 1, 7)));
+ actualUnique.add(100L);
+
+ // Sparse Multicol random values (most likely each value is unique)
+ inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrix(100, 10, 0.0, 100.0, 0.1, 7)));
+ actualUnique.add(98L); //dense representation
+ inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrix(100, 1000, 0.0, 100.0, 0.1, 7)));
+ actualUnique.add(9823L+1); //sparse representation
+
+ // MultiCol Inputs (using integers)
+ inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrixIntV(5000, 5000, 1, 100, 1, 8)));
+ actualUnique.add(99L);
+ inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrixIntV(1024, 10240, 1, 100, 1, 7)));
+ actualUnique.add(99L);
+ inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrixIntV(10240, 1024, 1, 100, 1, 7)));
+ actualUnique.add(99L);
+ inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrixIntV(1024, 10241, 1, 1500, 1, 7)));
+ actualUnique.add(1499L);
+ inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrixIntV(1024, 10241, 0, 3000, 1, 7)));
+ actualUnique.add(3000L);
+ inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrixIntV(1024, 10241, 0, 6000, 1, 7)));
+ actualUnique.add(6000L);
+
+ // Sparse Inputs
+ inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrixIntV(1024, 10241, 0, 3000, 0.1, 7)));
+ actualUnique.add(3000L);
+ // inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrixIntV(10240, 10241, 0, 5000, 0.1, 7)));
+ // actualUnique.add(5000L);
+ // inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrixIntV(10240, 10241, 0, 10000, 0.1, 7)));
+ // actualUnique.add(10000L);
+
+ for(CountDistinctTypes et : esT) {
+ for(HashType ht : HashType.values()) {
+ if((ht == HashType.ExpHash && et == CountDistinctTypes.KMV) ||
+ (ht == HashType.StandardJava && et == CountDistinctTypes.KMV)) {
+ String errorMessage = "Invalid hashing configuration using " + ht + " and " + et;
+ tests.add(new Object[] {et, inputs.get(0), actualUnique.get(0), ht, DMLException.class,
+ errorMessage, 0.0});
+ }
+ else if(et == CountDistinctTypes.HLL) {
+ tests.add(new Object[] {et, inputs.get(0), actualUnique.get(0), ht, NotImplementedException.class,
+ "HyperLogLog not implemented", 0.0});
+ }
+ else if (et != CountDistinctTypes.COUNT) {
+ for(int i = 0; i < inputs.size(); i++) {
+ // allowing the estimate to be 15% off
+ tests.add(new Object[] {et, inputs.get(i), actualUnique.get(i), ht, null, null, 0.15});
+ }
+ }
+ }
+ if (et == CountDistinctTypes.COUNT){
+ for(int i = 0; i < inputs.size(); i++) {
+ tests.add(new Object[] {et, inputs.get(i), actualUnique.get(i), null, null, null, 0.0001});
+ }
+ }
+ }
+ return tests;
+ }
+
+ @Parameterized.Parameter
+ public CountDistinctTypes et;
+ @Parameterized.Parameter(1)
+ public MatrixBlock in;
+ @Parameterized.Parameter(2)
+ public long nrUnique;
+ @Parameterized.Parameter(3)
+ public HashType ht;
+
+ // Exception handling
+ @Parameterized.Parameter(4)
+ public Class<? extends Exception> expectedException;
+ @Parameterized.Parameter(5)
+ public String expectedExceptionMsg;
+
+ @Rule
+ public ExpectedException thrown = ExpectedException.none();
+
+ // relative tolerance for the result, set per test case in data() (0.15 for KMV, 0.0001 for COUNT).
+ @Parameterized.Parameter(6)
+ public double epsilon;
+
+ /**
+ * Runs the estimator and checks the count is within nrUnique * epsilon, or that the expected exception is thrown.
+ */
+ @Test
+ public void testEstimation() {
+ // register the expected exception (if any) before executing the operator
+ if(expectedException != null) {
+ thrown.expect(expectedException);
+ thrown.expectMessage(expectedExceptionMsg);
+ }
+
+ Integer out = 0;
+ CountDistinctOperator op = new CountDistinctOperator(et, ht);
+ try {
+ out = LibMatrixCountDistinct.estimateDistinctValues(in, op);
+ }
+ catch(DMLException e) {
+ throw e; // possibly expected: let the ExpectedException rule verify it
+ }
+ catch(NotImplementedException e) {
+ throw e; // possibly expected: let the ExpectedException rule verify it
+ }
+ catch(Exception e) {
+ // any other exception is a test failure; keep it as the cause instead of printing it
+ throw new AssertionError(this.toString(), e);
+ }
+
+ int count = out;
+ boolean success = Math.abs(nrUnique - count) <= nrUnique * epsilon;
+ String msg = this.toString() + "\n" + count + " unique values, actual:" + nrUnique + " with eps of " + epsilon;
+ Assert.assertTrue(msg, success);
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(et);
+ if(ht != null){
+ sb.append("-" + ht);
+ }
+ sb.append(" nrUnique:" + nrUnique);
+ sb.append(" & input size:" + in.getNumRows() + "," + in.getNumColumns());
+ sb.append(" sparse: " + in.isInSparseFormat());
+ if(expectedException != null) {
+ sb.append("\nExpected Exception: " + expectedException);
+ sb.append("\nExpected Exception Msg: " + expectedExceptionMsg);
+ }
+ return sb.toString();
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/component/misc/UtilHash.java b/src/test/java/org/apache/sysds/test/component/misc/UtilHash.java
new file mode 100644
index 0000000..0e07d6c
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/misc/UtilHash.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.misc;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Set;
+
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Hash;
+import org.apache.sysds.utils.Hash.HashType;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(value = Parameterized.class)
+public class UtilHash {
+
+ @Parameters
+ public static Collection<Object[]> data() {
+ ArrayList<Object[]> tests = new ArrayList<>();
+ tests.add(new Object[] {100, 2, 0.0, 1.0});
+ tests.add(new Object[] {100, 5, Double.MIN_VALUE, Double.MAX_VALUE});
+ tests.add(new Object[] {1000, 50, Double.MIN_VALUE, Double.MAX_VALUE});
+ tests.add(new Object[] {10000, 500, Double.MIN_VALUE, Double.MAX_VALUE});
+ tests.add(new Object[] {1000, 500, Double.MIN_VALUE, Double.MAX_VALUE});
+ tests.add(new Object[] {1000, 500, 0.0, 1.0});
+ tests.add(new Object[] {1000, 500, 0.0, 100.0});
+ tests.add(new Object[] {1000, 500, 0.0, 0.0000001});
+ tests.add(new Object[] {1000, 1000, 0.0, 0.00000001});
+ tests.add(new Object[] {1000000, 1000000, 0.0, 0.00000001});
+
+ ArrayList<Object[]> actualTests = new ArrayList<>();
+
+ Set<HashType> validHashTypes = new HashSet<>();
+ for(HashType ht : HashType.values()) validHashTypes.add(ht);
+ validHashTypes.remove(HashType.ExpHash);
+
+ for(HashType ht : validHashTypes) {
+ for(int i = 0; i < tests.size(); i++) {
+ actualTests.add(new Object[] {tests.get(i)[0], tests.get(i)[1], tests.get(i)[2], tests.get(i)[3], ht});
+ }
+ }
+
+ return actualTests;
+ }
+
+ @Parameterized.Parameter
+ public int nrKeys = 1000;
+ @Parameterized.Parameter(1)
+ public int nrBuckets = 50;
+ @Parameterized.Parameter(2)
+ public double min;
+ @Parameterized.Parameter(3)
+ public double max;
+ @Parameterized.Parameter(4)
+ public HashType ht;
+
+ private double epsilon = 0.05;
+
+ /**
+ * Hashes nrKeys random values from [min, max] into nrBuckets buckets and requires the uniformity ratio to be within epsilon of 1.
+ */
+ @Test
+ public void chiSquaredTest() {
+ // https://en.wikipedia.org/wiki/Hash_function#Uniformity
+ double[] input = TestUtils.generateTestMatrix(1, nrKeys, min, max, 1.0, 10)[0];
+ int[] buckets = new int[nrBuckets];
+ // count how many keys fall into each bucket (Double.valueOf replaces the deprecated constructor)
+ for(double x : input) {
+ int hv = Hash.hash(Double.valueOf(x), ht);
+ buckets[Math.abs(hv % nrBuckets)] += 1;
+ }
+
+ // numerator of the uniformity ratio: sum of b * (b + 1) / 2 over all bucket sizes b
+ double top = 0;
+ for(int b : buckets) {
+ top += b * (b + 1.0) / 2.0;
+ }
+
+ double res = top / ((nrKeys / (2.0 * nrBuckets)) * (nrKeys + 2.0 * nrBuckets - 1));
+ boolean success = Math.abs(res - 1) <= epsilon;
+ Assert.assertTrue("Chi squared hashing test: " + res + " should be close to 1, with hashing: " + ht, success);
+ }
+
+}
\ No newline at end of file
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinFactorizationTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinFactorizationTest.java
index 6510b6d..1da15bb 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinFactorizationTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinFactorizationTest.java
@@ -111,7 +111,7 @@
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
//generate input and write incl meta data
- double[][] Xa = TestUtils.generateTestMatrix(rows, cols, 1, 10, sparsity, 7);
+ double[][] Xa = TestUtils.generateTestMatrix(rows, cols, 1.0, 10.0, sparsity, 7);
writeInputMatrixWithMTD("X", Xa, true);
//run test case
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
new file mode 100644
index 0000000..74772e0
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinct;
+
+import org.apache.sysds.lops.LopProperties;
+import org.junit.Test;
+
+/**
+ * Functional tests for the {@code countDistinct} script; the shared test cases are inherited
+ * from {@link CountDistinctBase}.
+ */
+public class CountDistinct extends CountDistinctBase {
+
+ public static final String TEST_NAME = "countDistinct";
+ public static final String TEST_DIR = "functions/countDistinct/";
+ public static final String TEST_CLASS_DIR = TEST_DIR + CountDistinct.class.getSimpleName() + "/";
+
+ @Override
+ protected String getTestClassDir() {
+ return TEST_CLASS_DIR;
+ }
+
+ @Override
+ protected String getTestName() {
+ return TEST_NAME;
+ }
+
+ @Override
+ protected String getTestDir() {
+ return TEST_DIR;
+ }
+
+ /** Smallest input: one row, one column, one distinct value. */
+ @Test
+ public void testSimple1by1() {
+ LopProperties.ExecType ex = LopProperties.ExecType.CP;
+ countDistinctTest(1, 1, 1, ex, 0.00001);
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApprox.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApprox.java
new file mode 100644
index 0000000..8d0d242
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApprox.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinct;
+
+import org.apache.sysds.lops.LopProperties;
+import org.junit.Test;
+
+public class CountDistinctApprox extends CountDistinctBase {
+
+ private final static String TEST_NAME = "countDistinctApprox";
+ private final static String TEST_DIR = "functions/countDistinct/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApprox.class.getSimpleName() + "/";
+
+ public CountDistinctApprox(){
+ percentTolerance = 0.1;
+ }
+
+ @Test
+ public void testXXLarge() {
+ LopProperties.ExecType ex = LopProperties.ExecType.CP;
+ double tolerance = 9000 * percentTolerance;
+ countDistinctTest(9000, 10000, 5000, ex, tolerance);
+ }
+
+ @Override
+ protected String getTestClassDir() {
+ return TEST_CLASS_DIR;
+ }
+
+ @Override
+ protected String getTestName() {
+ return TEST_NAME;
+ }
+
+ @Override
+ protected String getTestDir() {
+ return TEST_DIR;
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
new file mode 100644
index 0000000..6a9b096
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinct;
+
+import static org.junit.Assert.assertTrue;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public abstract class CountDistinctBase extends AutomatedTestBase {
+
+ protected abstract String getTestClassDir();
+ protected abstract String getTestName();
+ protected abstract String getTestDir();
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(getTestName(), new TestConfiguration(getTestClassDir(), getTestName(), new String[] { "A.scalar" }));
+ }
+
+ protected double percentTolerance = 0.0;
+ protected double baseTolerance = 0.0001;
+
+ @Test
+ public void testSmall() {
+ LopProperties.ExecType ex = LopProperties.ExecType.CP;
+ double tolerance = baseTolerance + 50 * percentTolerance;
+ countDistinctTest(50, 50, 50, ex,tolerance);
+ }
+
+ @Test
+ public void testLarge() {
+ LopProperties.ExecType ex = LopProperties.ExecType.CP;
+ double tolerance = baseTolerance + 800 * percentTolerance;
+ countDistinctTest(800, 1000, 1000, ex,tolerance);
+ }
+
+ @Test
+ public void testXLarge() {
+ LopProperties.ExecType ex = LopProperties.ExecType.CP;
+ double tolerance = baseTolerance + 1723 * percentTolerance;
+ countDistinctTest(1723, 5000, 2000, ex,tolerance);
+ }
+
+ @Test
+ public void test1Unique() {
+ LopProperties.ExecType ex = LopProperties.ExecType.CP;
+ double tolerance = 0.00001;
+ countDistinctTest(1, 100, 1000, ex,tolerance);
+ }
+
+ @Test
+ public void test2Unique() {
+ LopProperties.ExecType ex = LopProperties.ExecType.CP;
+ double tolerance = 0.00001;
+ countDistinctTest(2, 100, 1000, ex,tolerance);
+ }
+
+ @Test
+ public void test120Unique() {
+ LopProperties.ExecType ex = LopProperties.ExecType.CP;
+ double tolerance = 0.00001 + 120 * percentTolerance;
+ countDistinctTest(120, 100, 1000, ex,tolerance);
+ }
+
+ /** Runs the test script for the given sizes and compares the scalar it writes against numberDistinct. */
+ public void countDistinctTest(int numberDistinct, int cols, int rows, LopProperties.ExecType instType, double tolerance) {
+ Types.ExecMode platformOld = setExecMode(instType);
+ try {
+ loadTestConfiguration(getTestConfiguration(getTestName()));
+ String HOME = SCRIPT_DIR + getTestDir();
+ fullDMLScriptName = HOME + getTestName() + ".dml";
+ String out = output("A");
+ // script arguments: $1 = number of distinct values, $2 = rows, $3 = cols, $4 = output file
+ programArgs = new String[] { "-args", String.valueOf(numberDistinct), String.valueOf(rows),
+ String.valueOf(cols), out};
+ runTest(true, false, null, -1);
+ writeExpectedScalar("A", numberDistinct);
+ compareResults(tolerance);
+ } catch (Exception e) {
+ // fail the test but keep the original exception as the cause instead of printing it
+ throw new AssertionError("Exception in execution: " + e.getMessage(), e);
+ } finally {
+ rtplatform = platformOld;
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml b/src/test/scripts/functions/countDistinct/countDistinct.dml
new file mode 100644
index 0000000..a12ffe2
--- /dev/null
+++ b/src/test/scripts/functions/countDistinct/countDistinct.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, seed = 7))
+res = countDistinct(input)
+write(res, $4, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinctApprox.dml b/src/test/scripts/functions/countDistinct/countDistinctApprox.dml
new file mode 100644
index 0000000..e8b964e
--- /dev/null
+++ b/src/test/scripts/functions/countDistinct/countDistinctApprox.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, seed = 7))
+res = countDistinctApprox(input)
+write(res, $4, format="text")
\ No newline at end of file