Merge pull request #3 from DataSketches/svd_merge
major refactor
diff --git a/pom.xml b/pom.xml
index 96777d7..767e7c1 100644
--- a/pom.xml
+++ b/pom.xml
@@ -93,6 +93,16 @@
<name>bintray</name>
<url>https://jcenter.bintray.com</url>
</repository>
+ <repository>
+ <id>sonatype-releases</id>
+ <url>https://oss.sonatype.org/content/repositories/releases/</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
+ </repository>
</repositories>
<distributionManagement>
@@ -137,9 +147,15 @@
<dependencies>
<dependency>
+ <groupId>com.googlecode.matrix-toolkits-java</groupId>
+ <artifactId>mtj</artifactId>
+ <version>1.0.4</version>
+ </dependency>
+
+ <dependency>
<groupId>org.ojalgo</groupId>
<artifactId>ojalgo</artifactId>
- <version>44.0.0</version>
+ <version>45.0.0</version>
</dependency>
<dependency>
diff --git a/src/main/java/com/yahoo/sketches/vector/MatrixFamily.java b/src/main/java/com/yahoo/sketches/vector/MatrixFamily.java
index 39878bf..22a6954 100644
--- a/src/main/java/com/yahoo/sketches/vector/MatrixFamily.java
+++ b/src/main/java/com/yahoo/sketches/vector/MatrixFamily.java
@@ -23,7 +23,7 @@
*/
public enum MatrixFamily {
/**
- * The Frequent Directions sketch is used for approximate Singular Value Decomposition (SVD) of a
+ * The Frequent Directions sketch is used for approximate Singular Value Decomposition (SVD) of a
* matrix.
*/
MATRIX(128, "Matrix", 2, 3),
diff --git a/src/main/java/com/yahoo/sketches/vector/decomposition/FrequentDirections.java b/src/main/java/com/yahoo/sketches/vector/decomposition/FrequentDirections.java
index f4a33ad..04c8a9f 100644
--- a/src/main/java/com/yahoo/sketches/vector/decomposition/FrequentDirections.java
+++ b/src/main/java/com/yahoo/sketches/vector/decomposition/FrequentDirections.java
@@ -27,17 +27,12 @@
import static com.yahoo.sketches.vector.decomposition.PreambleUtil.insertSVAdjustment;
import static com.yahoo.sketches.vector.decomposition.PreambleUtil.insertSerVer;
-import org.ojalgo.array.Array1D;
-import org.ojalgo.matrix.decomposition.SingularValue;
-import org.ojalgo.matrix.store.MatrixStore;
-import org.ojalgo.matrix.store.PrimitiveDenseStore;
-import org.ojalgo.matrix.store.SparseStore;
-
import com.yahoo.memory.Memory;
import com.yahoo.memory.WritableMemory;
import com.yahoo.sketches.vector.MatrixFamily;
import com.yahoo.sketches.vector.matrix.Matrix;
import com.yahoo.sketches.vector.matrix.MatrixBuilder;
+import com.yahoo.sketches.vector.matrix.MatrixType;
/**
* This class implements the Frequent Directions algorithm proposed by Edo Liberty in "Simple and
@@ -48,6 +43,9 @@
* @author Jon Malkin
*/
public final class FrequentDirections {
+ private static final MatrixType DEFAULT_MATRIX_TYPE = MatrixType.OJALGO;
+ private static final SVDAlgo DEFAULT_SVD_ALGO = SVDAlgo.SYM;
+
private final int k_;
private final int l_;
private final int d_;
@@ -55,20 +53,33 @@
private double svAdjustment_;
- private PrimitiveDenseStore B_;
- transient private int nextZeroRow_;
+ private Matrix B_;
- transient private final double[] sv_; // pre-allocated to fetch singular values
- transient private final SparseStore<Double> S_; // to hold singular value matrix
+ private SVDAlgo algo_ = DEFAULT_SVD_ALGO;
+
+ transient private int nextZeroRow_;
+ transient private MatrixOps svd_; // avoids re-initializing
/**
- * Creates a new instance of a Frequent Directions sketch.
+ * Creates a new instance of a Frequent Directions sketch using the default Linear Algebra backing library
* @param k Number of dimensions (rows) in the sketch output
* @param d Number of dimensions per input vector (columns)
* @return An empty Frequent Directions sketch
*/
public static FrequentDirections newInstance(final int k, final int d) {
- return new FrequentDirections(k, d);
+ return newInstance(k, d, DEFAULT_MATRIX_TYPE);
+ }
+
+ /**
+ * Creates a new instance of a Frequent Directions sketch using a specific MatrixType
+ * Package-private until (if ever) MTJ works properly.
+ * @param k Number of dimensions (rows) in the sketch output
+ * @param d Number of dimensions per input vector (columns)
+ * @param type MatrixType to use for backing matrix. Impacts choice of SVD library.
+ * @return An empty Frequent Directions sketch
+ */
+ static FrequentDirections newInstance(final int k, final int d, final MatrixType type) {
+ return new FrequentDirections(k, d, null, type);
}
/**
@@ -77,6 +88,17 @@
* @return A Frequent Directions sketch
*/
public static FrequentDirections heapify(final Memory srcMem) {
+ return heapify(srcMem, DEFAULT_MATRIX_TYPE);
+ }
+
+ /**
+ * Instantiates a Frequent Directions sketch from a serialized image using a specific MatrixType.
+ * Package-private until (if ever) MTJ works properly.
+ * @param srcMem Memory containing the serialized image of a Frequent Directions sketch
+ * @param type The MatrixType to use with this instance
+ * @return A Frequent Directions sketch
+ */
+ static FrequentDirections heapify(final Memory srcMem, final MatrixType type) {
final int preLongs = getAndCheckPreLongs(srcMem);
final int serVer = extractSerVer(srcMem);
if (serVer != SER_VER) {
@@ -100,11 +122,10 @@
final long offsetBytes = preLongs * Long.BYTES;
final long mtxBytes = srcMem.getCapacity() - offsetBytes;
- final Matrix B = Matrix.heapify(srcMem.region(offsetBytes, mtxBytes), MatrixBuilder.Algo.OJALGO);
+ final Matrix B = Matrix.heapify(srcMem.region(offsetBytes, mtxBytes), type);
assert B != null;
- final FrequentDirections fd
- = new FrequentDirections(k, d, (PrimitiveDenseStore) B.getRawObject());
+ final FrequentDirections fd = new FrequentDirections(k, d, B, B.getMatrixType());
fd.n_ = extractN(srcMem);
fd.nextZeroRow_ = numRows;
fd.svAdjustment_ = extractSVAdjustment(srcMem);
@@ -113,10 +134,11 @@
}
private FrequentDirections(final int k, final int d) {
- this(k, d, null);
+ this(k, d, null, DEFAULT_MATRIX_TYPE);
}
- private FrequentDirections(final int k, final int d, final PrimitiveDenseStore B) {
+ // uses MatrixType of B, if present, otherwise falls back to type input
+ private FrequentDirections(final int k, final int d, final Matrix B, final MatrixType type) {
if (k < 1) {
throw new IllegalArgumentException("Number of projected dimensions must be at least 1");
}
@@ -138,14 +160,10 @@
n_ = 0;
if (B == null) {
- B_ = PrimitiveDenseStore.FACTORY.makeZero(l_, d_);
+ B_ = new MatrixBuilder().setType(type).build(l_, d_);
} else {
B_ = B;
}
-
- final int svDim = Math.min(l_, d_);
- sv_ = new double[svDim];
- S_ = SparseStore.makePrimitive(svDim, svDim);
}
/**
@@ -166,10 +184,7 @@
reduceRank();
}
- // dense input so set all values
- for (int i = 0; i < vector.length; ++i) {
- B_.set(nextZeroRow_, i, vector[i]);
- }
+ B_.setRow(nextZeroRow_, vector);
++n_;
++nextZeroRow_;
@@ -194,11 +209,7 @@
reduceRank();
}
- final Array1D<Double> rv = fd.B_.sliceRow(m);
- for (int i = 0; i < rv.count(); ++i) {
- B_.set(nextZeroRow_, i, rv.get(i));
- }
-
+ B_.setRow(nextZeroRow_, fd.B_.getRow(m));
++nextZeroRow_;
}
@@ -233,6 +244,14 @@
public long getN() { return n_; }
/**
+ * Sets the SVD algorithm to use, allowing exact or approximate computation. See {@link SVDAlgo} for
+ * details. The cached MatrixOps helper is created for one specific algorithm, so it is discarded
+ * whenever the algorithm changes; otherwise later operations would keep using the previous algorithm.
+ * @param algo The SVDAlgo type to use
+ */
+ public void setSVDAlgo(final SVDAlgo algo) {
+ if (algo != algo_) {
+ algo_ = algo;
+ svd_ = null; // rebuilt lazily, with the new algorithm, on next use
+ }
+ }
+
+ /**
* Returns the singular values of the sketch, adjusted for the mass subtracted off during the
* algorithm.
* @return An array of singular values.
@@ -249,17 +268,19 @@
* @return An array of singular values.
*/
public double[] getSingularValues(final boolean compensative) {
- final SingularValue<Double> svd = SingularValue.make(B_);
- svd.compute(B_);
- svd.getSingularValues(sv_);
+ if (svd_ == null) {
+ svd_ = MatrixOps.newInstance(B_, algo_, k_);
+ }
- double medianSVSq = sv_[k_ - 1]; // (l_/2)th item, not yet squared
+ final double[] sv = svd_.getSingularValues(B_);
+
+ double medianSVSq = sv[k_ - 1]; // (l_/2)th item, not yet squared
medianSVSq *= medianSVSq;
final double tmpSvAdj = svAdjustment_ + medianSVSq;
final double[] svList = new double[k_];
for (int i = 0; i < k_ - 1; ++i) {
- final double val = sv_[i];
+ final double val = sv[i];
double adjSqSV = val * val - medianSVSq;
if (compensative) { adjSqSV += tmpSvAdj; }
svList[i] = adjSqSV < 0 ? 0.0 : Math.sqrt(adjSqSV);
@@ -269,22 +290,16 @@
}
/**
- * Returns an orthonormal projection Matrix that can be used to project input vectors into the
+ * Returns an orthonormal projection Matrix V^T that can be used to project input vectors into the
* k-dimensional space represented by the sketch.
+ * The current sketch matrix is decomposed on every call, so the result reflects all rows added so far.
* @return An orthonormal Matrix object
*/
public Matrix getProjectionMatrix() {
- final SingularValue<Double> svd = SingularValue.make(B_);
- svd.compute(B_);
- final MatrixStore<Double> m = svd.getQ2().transpose();
-
- // not super efficient...
- final Matrix result = Matrix.builder().build(k_, d_);
- for (int i = 0; i < k_ - 1; ++i) { // last SV is 0
- result.setRow(i, m.sliceRow(i).toRawCopy1D());
+ if (svd_ == null) {
+ svd_ = MatrixOps.newInstance(B_, algo_, k_);
}
- return result;
+ // Decompose B_ now: the no-arg getVt() only hands back whatever an earlier decomposition left behind,
+ // which is unallocated on a fresh MatrixOps and stale once more rows have been added.
+ // NOTE(review): with the MTJ backend and SVDAlgo.FULL the decomposition works directly on B_'s data.
+ return svd_.getVt(B_);
}
/**
@@ -308,44 +323,28 @@
/**
* Returns a Matrix with the current state of the sketch. Call <tt>trim()</tt> first to ensure
- * no more than k rows. If compensative, uses only the top k singular values.
- * @param compensative If true, applies adjustment to singular values based on the cumulative
- * weight subtracted off
- * @return A Matrix representing the data in this sketch
+ * no more than k rows. If compensative, uses only the top k singular values. If not applying compensation
+ * factor, this method returns the actual data object meaning any changes to the result data will corrupt
+ * the sketch.
+ * @param compensative If true, returns a copy of the data matrix after applying adjustment to singular
+ * values based on the cumulative weight subtracted off. If false, returns the actual
+ * data matrix.
+ * @return A Matrix of the data in this sketch, possibly adjusted by compensating for subtracted weight,
+ * or null when isEmpty() reports true.
*/
public Matrix getResult(final boolean compensative) {
if (isEmpty()) {
return null;
}
- final PrimitiveDenseStore result = PrimitiveDenseStore.FACTORY.makeZero(nextZeroRow_, d_);
-
if (compensative) {
- final SingularValue<Double> svd = SingularValue.make(B_);
- svd.compute(B_);
- svd.getSingularValues(sv_);
-
- for (int i = 0; i < k_ - 1; ++i) {
- final double val = sv_[i];
- final double adjSV = Math.sqrt(val * val + svAdjustment_);
- S_.set(i, i, adjSV);
- }
- for (int i = k_ - 1; i < S_.countColumns(); ++i) {
- S_.set(i, i, 0.0);
+ if (svd_ == null) {
+ svd_ = MatrixOps.newInstance(B_, algo_, k_);
}
- S_.multiply(svd.getQ2().transpose(), result);
+ return svd_.applyAdjustment(B_, svAdjustment_); // both MatrixOps implementations copy B_ before decomposing
} else {
- // there's gotta be a better way to copy rows than this
- for (int i = 0; i < nextZeroRow_; ++i) {
- int j = 0;
- for (double d : B_.sliceRow(i)) {
- result.set(i, j++, d);
- }
- }
+ return B_; // the live internal matrix, not a copy
}
-
- return Matrix.wrap(result);
}
/**
@@ -365,13 +364,11 @@
final boolean empty = isEmpty();
final int familyId = MatrixFamily.FREQUENTDIRECTIONS.getID();
- final Matrix wrapB = Matrix.wrap(B_);
-
final int preLongs = empty
? MatrixFamily.FREQUENTDIRECTIONS.getMinPreLongs()
: MatrixFamily.FREQUENTDIRECTIONS.getMaxPreLongs();
- final int mtxBytes = empty ? 0 : wrapB.getCompactSizeBytes(nextZeroRow_, d_);
+ final int mtxBytes = empty ? 0 : B_.getCompactSizeBytes(nextZeroRow_, d_);
final int outBytes = (preLongs * Long.BYTES) + mtxBytes;
final byte[] outArr = new byte[outBytes];
@@ -395,7 +392,7 @@
insertSVAdjustment(memObj, memAddr, svAdjustment_);
memOut.putByteArray(preLongs * Long.BYTES,
- wrapB.toCompactByteArray(nextZeroRow_, d_), 0, mtxBytes);
+ B_.toCompactByteArray(nextZeroRow_, d_), 0, mtxBytes);
return outArr;
}
@@ -462,30 +459,29 @@
return sb.toString();
}
- final Matrix mtx = Matrix.wrap(B_);
- final int tmpColDim = (int) mtx.getNumColumns();
+ final int tmpColDim = (int) B_.getNumColumns();
sb.append(" Matrix data :").append(LS);
- sb.append(mtx.getClass().getName());
+ sb.append(B_.getClass().getName());
sb.append(" < ").append(nextZeroRow_).append(" x ").append(tmpColDim).append(" >");
// First element
- sb.append("\n{ { ").append(mtx.getElement(0, 0));
+ sb.append("\n{ { ").append(String.format("%.3f", B_.getElement(0, 0)));
// Rest of the first row
for (int j = 1; j < tmpColDim; j++) {
- sb.append(",\t").append(mtx.getElement(0, j));
+ sb.append(",\t").append(String.format("%.3f", B_.getElement(0, j)));
}
// For each of the remaining rows
for (int i = 1; i < nextZeroRow_; i++) {
// First column
- sb.append(" },\n{ ").append(mtx.getElement(i, 0));
+ sb.append(" },\n{ ").append(String.format("%.3f", B_.getElement(i, 0)));
// Remaining columns
for (int j = 1; j < tmpColDim; j++) {
- sb.append(",\t").append(mtx.getElement(i, j));
+ sb.append(",\t").append(String.format("%.3f", B_.getElement(i, j)));
}
}
@@ -502,34 +498,12 @@
double getSvAdjustment() { return svAdjustment_; }
private void reduceRank() {
- final SingularValue<Double> svd = SingularValue.make(B_);
- svd.compute(B_);
- svd.getSingularValues(sv_);
-
- if (sv_.length >= k_) {
- double medianSVSq = sv_[k_ - 1]; // (l_/2)th item, not yet squared
- medianSVSq *= medianSVSq;
- svAdjustment_ += medianSVSq; // always track, even if not using compensative mode
- for (int i = 0; i < k_ - 1; ++i) {
- final double val = sv_[i];
- final double adjSqSV = val * val - medianSVSq;
- S_.set(i, i, adjSqSV < 0 ? 0.0 : Math.sqrt(adjSqSV));
- }
- for (int i = k_ - 1; i < S_.countColumns(); ++i) {
- S_.set(i, i, 0.0);
- }
- nextZeroRow_ = k_;
- } else {
- for (int i = 0; i < sv_.length; ++i) {
- S_.set(i, i, sv_[i]);
- }
- for (int i = sv_.length; i < S_.countColumns(); ++i) {
- S_.set(i, i, 0.0);
- }
- nextZeroRow_ = sv_.length;
- throw new RuntimeException("Running with d < 2k not yet supported");
+ if (svd_ == null) { // create the cached MatrixOps helper on first use
+ svd_ = MatrixOps.newInstance(B_, algo_, k_);
}
- S_.multiply(svd.getQ2().transpose()).supplyTo(B_);
+ final double newSvAdjustment = svd_.reduceRank(B_); // both MatrixOps implementations write the reduced matrix back into B_
+ svAdjustment_ += newSvAdjustment; // accumulated across every reduction
+ nextZeroRow_ = (int) Math.min(k_ - 1, n_); // the reduction zeroes singular values from index k_ - 1 onward
}
}
diff --git a/src/main/java/com/yahoo/sketches/vector/decomposition/MatrixOps.java b/src/main/java/com/yahoo/sketches/vector/decomposition/MatrixOps.java
new file mode 100644
index 0000000..078f4b9
--- /dev/null
+++ b/src/main/java/com/yahoo/sketches/vector/decomposition/MatrixOps.java
@@ -0,0 +1,147 @@
+/* Directly derived from LGPL'd Matrix Toolkit for Java:
+ * https://github.com/fommil/matrix-toolkits-java/blob/master/src/main/java/no/uib/cipr/matrix/SVD.java
+ */
+
+package com.yahoo.sketches.vector.decomposition;
+
+import com.yahoo.sketches.vector.matrix.Matrix;
+
+/**
+ * Computes singular value decompositions and related Matrix operations needed by Frequent Directions. May
+ * return as many singular values as exist, but other operations will limit output to k dimensions.
+ */
+abstract class MatrixOps {
+
+ // iterations for SISVD
+ private static final int DEFAULT_NUM_ITER = 8;
+
+ /**
+ * Matrix dimensions
+ */
+ final int n_; // rows
+ final int d_; // columns
+
+ /**
+ * Target number of dimensions
+ */
+ final int k_;
+
+ /**
+ * Singular value decomposition method to use
+ */
+ final SVDAlgo algo_;
+
+ int numSISVDIter_;
+
+ /**
+ * Creates an empty MatrixOps object to support Frequent Directions matrix operations. The concrete
+ * implementation is chosen from A's MatrixType; an unrecognized type raises IllegalArgumentException.
+ * @param A Matrix of the required type and correct dimensions
+ * @param algo Enum indicating method to use for SVD
+ * @param k Target number of dimensions for results
+ * @return an empty MatrixOps object
+ */
+ public static MatrixOps newInstance(final Matrix A, final SVDAlgo algo, final int k) {
+ final int n = (int) A.getNumRows();
+ final int d = (int) A.getNumColumns();
+
+ final MatrixOps mo;
+
+ switch (A.getMatrixType()) {
+ case OJALGO:
+ mo = new MatrixOpsImplOjAlgo(n, d, algo, k);
+ break;
+
+ case MTJ:
+ mo = new MatrixOpsImplMTJ(n, d, algo, k);
+ break;
+
+ default:
+ throw new IllegalArgumentException("Unknown MatrixType: " + A.getMatrixType().toString());
+ }
+
+ if (algo == SVDAlgo.SISVD) {
+ mo.setNumSISVDIter((int) Math.ceil(Math.log(d))); // replaces DEFAULT_NUM_ITER with ceil(ln(d))
+ }
+ return mo;
+ }
+
+ MatrixOps(final int n, final int d, final SVDAlgo algo, final int k) {
+ // TODO: make these actual checks
+ assert n > 0;
+ assert d > 0;
+ assert n < d;
+ assert k > 0;
+ assert k < n;
+
+ n_ = n;
+ d_ = d;
+ algo_ = algo;
+ k_ = k;
+
+ numSISVDIter_ = DEFAULT_NUM_ITER;
+ }
+
+ /**
+ * Computes and returns the singular values, in descending order. May modify the internal state
+ * of this object.
+ * @param A Matrix to decompose
+ * @return Array of singular values
+ */
+ public double[] getSingularValues(final Matrix A) {
+ svd(A, false);
+ return getSingularValues();
+ }
+
+ /**
+ * Returns pre-computed singular values (stored in descending order). Does not perform new computation.
+ * @return Singular values from the last computation
+ */
+ abstract double[] getSingularValues();
+
+ /**
+ * Computes and returns the right singular vectors of A. May modify the internal state of this object.
+ * @param A Matrix to decompose; see svd() regarding A possibly being overwritten
+ * @return Matrix holding V^T. NOTE(review): both implementations in this package allocate it as n x d.
+ */
+ public Matrix getVt(final Matrix A) {
+ svd(A, true); // true: singular vectors are required, not just singular values
+ return getVt();
+ }
+
+ /**
+ * Returns the right singular vectors left behind by the last decomposition that computed vectors.
+ * Does not perform new computation. Both implementations in this package allocate the backing storage
+ * lazily, so it is still null if no decomposition has requested vectors yet.
+ * @return Matrix holding V^T. NOTE(review): both implementations in this package allocate it as n x d.
+ */
+ abstract Matrix getVt();
+
+ /**
+ * Performs a Frequent Directions rank reduction with the SVDAlgo used when obtaining the instance.
+ * Modifies internal state, with results queried via getVt() and getSingularValues(). Both
+ * implementations in this package also write the reduced matrix back into A.
+ * @param A Matrix to decompose and reduce
+ * @return The amount of weight subtracted from the squared singular values; both implementations in
+ * this package use the square of the entry at index k - 1 of the singular value array
+ */
+ abstract double reduceRank(final Matrix A);
+
+ /**
+ * Returns Matrix object reconstructed using the provided singular value adjustment. Requires first
+ * decomposing the matrix.
+ * @param A Matrix to decompose and adjust
+ * @param adjustment Amount by which to adjust the singular values
+ * @return A new Matrix based on A with singular values adjusted by adjustment
+ */
+ abstract Matrix applyAdjustment(final Matrix A, final double adjustment);
+
+ /**
+ * Computes a singular value decomposition of the provided Matrix.
+ *
+ * @param A Matrix to decompose. Size must conform, and it may be overwritten on return. Pass a copy to
+ * avoid this.
+ * @param computeVectors True to compute Vt, false if only the singular values are needed.
+ */
+ abstract void svd(final Matrix A, final boolean computeVectors);
+
+ void setNumSISVDIter(final int numSISVDIter) {
+ numSISVDIter_ = numSISVDIter;
+ }
+}
diff --git a/src/main/java/com/yahoo/sketches/vector/decomposition/MatrixOpsImplMTJ.java b/src/main/java/com/yahoo/sketches/vector/decomposition/MatrixOpsImplMTJ.java
new file mode 100644
index 0000000..eebf1f9
--- /dev/null
+++ b/src/main/java/com/yahoo/sketches/vector/decomposition/MatrixOpsImplMTJ.java
@@ -0,0 +1,333 @@
+/* Portions derived from LGPL'd Matrix Toolkit for Java:
+ * https://github.com/fommil/matrix-toolkits-java/blob/master/src/main/java/no/uib/cipr/matrix/SVD.java
+ */
+
+package com.yahoo.sketches.vector.decomposition;
+
+import java.util.concurrent.ThreadLocalRandom;
+
+import org.netlib.util.intW;
+
+import com.github.fommil.netlib.BLAS;
+import com.github.fommil.netlib.LAPACK;
+import com.yahoo.sketches.vector.matrix.Matrix;
+import com.yahoo.sketches.vector.matrix.MatrixImplMTJ;
+import com.yahoo.sketches.vector.matrix.MatrixType;
+import no.uib.cipr.matrix.DenseMatrix;
+import no.uib.cipr.matrix.MatrixEntry;
+import no.uib.cipr.matrix.NotConvergedException;
+import no.uib.cipr.matrix.QR;
+import no.uib.cipr.matrix.SVD;
+import no.uib.cipr.matrix.SymmDenseEVD;
+import no.uib.cipr.matrix.UpperSymmDenseMatrix;
+import no.uib.cipr.matrix.sparse.CompDiagMatrix;
+import no.uib.cipr.matrix.sparse.LinkedSparseMatrix;
+
+/**
+ * Computes singular value decompositions
+ */
+class MatrixOpsImplMTJ extends MatrixOps {
+
+ /**
+ * The singular values
+ */
+ private final double[] sv_;
+
+ /**
+ * Singular vectors, sparse version of singular value matrix
+ */
+ private DenseMatrix Vt_;
+ private CompDiagMatrix S_;
+
+ /**
+ * Work arrays for full SVD
+ */
+ private double[] work_;
+ private int[] iwork_;
+
+ /**
+ * Work arrays for SISVD
+ */
+ private DenseMatrix block_;
+ private DenseMatrix T_;
+
+ /**
+ * Work objects for SymmEVD
+ */
+ private SymmDenseEVD evd_;
+ private LinkedSparseMatrix rotS_;
+
+ /**
+ * Creates an empty MatrixOps
+ *
+ * @param n Number of rows in matrix
+ * @param d Number of columns in matrix
+ * @param algo SVD algorithm to apply
+ * @param k Target number of dimensions for any reduction operations
+ */
+ //MatrixOpsImplMTJ(final MatrixImplMTJ A, final SVDAlgo algo, final int k) {
+ MatrixOpsImplMTJ(final int n, final int d, final SVDAlgo algo, final int k) {
+ super(n, d, algo, k);
+
+ // Allocate space for the decomposition
+ sv_ = new double[Math.min(n_, d_)];
+ Vt_ = null; // lazy allocation
+ }
+
+ @Override
+ void svd(final Matrix A, final boolean computeVectors) {
+ assert A.getMatrixType() == MatrixType.MTJ;
+
+ if (A.getNumRows() != n_) {
+ throw new IllegalArgumentException("A.numRows() != n_");
+ } else if (A.getNumColumns() != d_) {
+ throw new IllegalArgumentException("A.numColumns() != d_");
+ }
+
+ if (computeVectors && Vt_ == null) {
+ Vt_ = new DenseMatrix(n_, d_);
+
+ final int[] diag = {0}; // only need the main diagonal
+ S_ = new CompDiagMatrix(n_, n_, diag);
+ }
+
+ switch (algo_) {
+ case FULL:
+ // make a copy if not computing vectors to avoid changing the data
+ final DenseMatrix mtx = computeVectors ? (DenseMatrix) A.getRawObject()
+ : new DenseMatrix((DenseMatrix) A.getRawObject());
+ computeFullSVD(mtx, computeVectors);
+ return;
+
+ case SISVD:
+ computeSISVD((DenseMatrix) A.getRawObject(), computeVectors);
+ return;
+
+ case SYM:
+ computeSymmEigSVD((DenseMatrix) A.getRawObject(), computeVectors);
+ return;
+
+ default:
+ throw new RuntimeException("SVDAlgo type not (yet?) supported: " + algo_.toString());
+ }
+ }
+
+ // Exact SVD destroys its input under MTJ; svd() already copies A for FULL when vectors are not requested
+ @Override
+ public double[] getSingularValues(final Matrix A) {
+ svd(A, false); // false: singular values only
+ return getSingularValues();
+ }
+
+ @Override
+ public double[] getSingularValues() {
+ return sv_;
+ }
+
+ @Override
+ Matrix getVt() {
+ return MatrixImplMTJ.wrap(Vt_);
+ }
+
+ @Override
+ double reduceRank(final Matrix A) {
+ svd(A, true);
+
+ double svAdjustment = 0.0;
+ S_.zero();
+
+ if (sv_.length >= k_) {
+ double medianSVSq = sv_[k_ - 1]; // (l_/2)th item, not yet squared
+ medianSVSq *= medianSVSq;
+ svAdjustment += medianSVSq; // always track, even if not using compensative mode
+ for (int i = 0; i < k_ - 1; ++i) {
+ final double val = sv_[i];
+ final double adjSqSV = val * val - medianSVSq;
+ S_.set(i, i, adjSqSV < 0 ? 0.0 : Math.sqrt(adjSqSV));
+ }
+ for (int i = k_ - 1; i < S_.numColumns(); ++i) {
+ S_.set(i, i, 0.0);
+ }
+ //nextZeroRow_ = k_;
+ } else {
+ for (int i = 0; i < sv_.length; ++i) {
+ S_.set(i, i, sv_[i]);
+ }
+ for (int i = sv_.length; i < S_.numColumns(); ++i) {
+ S_.set(i, i, 0.0);
+ }
+ //nextZeroRow_ = sv_.length;
+ throw new RuntimeException("Running with d < 2k not yet supported");
+ }
+
+ // store the result back in A
+ S_.mult(Vt_, (DenseMatrix) A.getRawObject());
+
+ return svAdjustment;
+ }
+
+ @Override
+ Matrix applyAdjustment(final Matrix A, final double svAdjustment) {
+ // copy A before decomposing
+ final DenseMatrix result = new DenseMatrix((DenseMatrix) A.getRawObject(), true);
+ svd(Matrix.wrap(result), true);
+
+ for (int i = 0; i < k_ - 1; ++i) {
+ final double val = sv_[i];
+ final double adjSV = Math.sqrt(val * val + svAdjustment);
+ S_.set(i, i, adjSV);
+ }
+ for (int i = k_ - 1; i < S_.numColumns(); ++i) {
+ S_.set(i, i, 0.0);
+ }
+
+ S_.mult(Vt_, result);
+
+ return Matrix.wrap(result);
+ }
+
+ private void allocateSpaceFullSVD(final boolean vectors) {
+ // Find workspace requirements
+ iwork_ = new int[8 * Math.min(n_, d_)];
+
+ // Query optimal workspace
+ final double[] workSize = new double[1];
+ final intW info = new intW(0);
+ LAPACK.getInstance().dgesdd("S", n_, d_, new double[0],
+ n_, new double[0], new double[0], n_,
+ new double[0], n_, workSize, -1, iwork_, info);
+
+ // Allocate workspace
+ int lwork;
+ if (info.val != 0) {
+ if (vectors) {
+ lwork = 3
+ * Math.min(n_, d_)
+ * Math.min(n_, d_)
+ + Math.max(
+ Math.max(n_, d_),
+ 4 * Math.min(n_, d_) * Math.min(n_, d_) + 4
+ * Math.min(n_, d_));
+ } else {
+ lwork = 3
+ * Math.min(n_, d_)
+ * Math.min(n_, d_)
+ + Math.max(
+ Math.max(n_, d_),
+ 5 * Math.min(n_, d_) * Math.min(n_, d_) + 4
+ * Math.min(n_, d_));
+ }
+ } else {
+ lwork = (int) workSize[0];
+ }
+
+ lwork = Math.max(lwork, 1);
+ work_ = new double[lwork];
+ }
+
+ private void allocateSpaceSISVD() {
+ block_ = new DenseMatrix(d_, k_);
+ T_ = new DenseMatrix(n_, k_);
+ // TODO: should allocate space for QR and final SVD here?
+ }
+
+ private void allocateSpaceSymmEigSVD() {
+ T_ = new DenseMatrix(n_, n_);
+ rotS_ = new LinkedSparseMatrix(n_, n_); // only need if computing vectors, but only O(n_) size
+ evd_ = new SymmDenseEVD(n_, true, true);
+ }
+
+ /**
+ * Runs LAPACK's dgesdd on A, storing singular values in sv_ and, when requested, V^T in Vt_.
+ * The work arrays are allocated on the first call and reused afterwards.
+ *
+ * @param A Matrix whose data array is handed directly to LAPACK
+ * @param computeVectors True to request job type "S" and fill Vt_, false for job type "N"
+ * @throws RuntimeException if LAPACK reports a positive info value (no convergence)
+ * @throws IllegalArgumentException if LAPACK reports a negative info value (illegal argument)
+ */
+ private void computeFullSVD(final DenseMatrix A, final boolean computeVectors) {
+ if (work_ == null) {
+ allocateSpaceFullSVD(computeVectors);
+ }
+
+ final intW info = new intW(0);
+ final String jobType = computeVectors ? "S" : "N";
+ LAPACK.getInstance().dgesdd(jobType, n_, d_, A.getData(),
+ n_, sv_, new double[0],
+ n_, computeVectors ? Vt_.getData() : new double[0],
+ n_, work_, work_.length, iwork_, info);
+
+ if (info.val > 0) {
+ throw new RuntimeException("Did not converge after a maximum number of iterations");
+ } else if (info.val < 0) {
+ // LAPACK convention: info = -i means the i-th argument had an illegal value
+ throw new IllegalArgumentException("dgesdd: argument " + (-info.val) + " had an illegal value");
+ }
+ }
+
+ /**
+ * Approximates the SVD of A by simultaneous iteration: a random Gaussian d_ x k_ block is
+ * orthonormalized, pushed through A and A^T numSISVDIter_ times with a QR step after each pass, and the
+ * projection A * block is then decomposed exactly. Singular values go to sv_; when requested, the first
+ * k_ rows of Vt_ receive the approximate right singular vectors.
+ *
+ * @param A Matrix to decompose
+ * @param computeVectors True to also fill Vt_
+ * @throws RuntimeException wrapping NotConvergedException if the final SVD does not converge
+ */
+ private void computeSISVD(final DenseMatrix A, final boolean computeVectors) {
+ if (block_ == null) {
+ allocateSpaceSISVD();
+ }
+
+ // want block_ filled as ~Normal(0,1))
+ final ThreadLocalRandom rand = ThreadLocalRandom.current();
+ for (MatrixEntry entry : block_) {
+ entry.set(rand.nextGaussian());
+ }
+ // TODO: in-line QR with direct LAPACK call
+ final QR qr = new QR(block_.numRows(), block_.numColumns());
+ block_ = qr.factor(block_).getQ(); // important for numeric stability
+
+ for (int i = 0; i < numSISVDIter_; ++i) {
+ A.mult(block_, T_);
+ A.transAmult(T_, block_);
+ block_ = qr.factor(block_).getQ(); // again, for stability
+ }
+
+ // Rayleigh-Ritz postprocessing
+ A.mult(block_, T_);
+
+ // TODO: use LAPACK directly
+ final SVD svd = new SVD(T_.numRows(), T_.numColumns(), computeVectors);
+ try {
+ svd.factor(T_);
+ } catch (final NotConvergedException e) {
+ // keep the original exception as the cause so its stack trace is not lost
+ throw new RuntimeException(e.getMessage(), e);
+ }
+ System.arraycopy(svd.getS(), 0, sv_, 0, svd.getS().length); // sv_ is final
+
+ if (computeVectors) {
+ // V^T = (block * V^T)^T = (V^T)^T * block^T
+ // using BLAS directly since Vt is (n_ x d_) but result here is only (k_ x d_)
+ BLAS.getInstance().dgemm("T", "T", k_, d_, k_,
+ 1.0, svd.getVt().getData(), k_, block_.getData(), d_,
+ 0.0, Vt_.getData(), n_);
+ }
+ }
+
+ /**
+ * Derives the SVD of A from an eigendecomposition of the n_ x n_ matrix A * A^T. Singular values are
+ * the square roots of the eigenvalues, stored in sv_ in descending order; when requested, Vt_ is
+ * computed from the eigenvectors, the inverse singular values and A.
+ *
+ * @param A Matrix to decompose
+ * @param computeVectors True to also fill Vt_
+ * @throws RuntimeException wrapping NotConvergedException if the eigendecomposition does not converge
+ */
+ private void computeSymmEigSVD(final DenseMatrix A, final boolean computeVectors) {
+ if (T_ == null) {
+ allocateSpaceSymmEigSVD();
+ }
+
+ // want left singular vectors U, aka eigenvectors of AA^T -- so compute that
+ A.transBmult(A, T_);
+ try {
+ // TODO: direct LAPACK call lets us get only the top k values/vectors rather than all
+ evd_.factor(new UpperSymmDenseMatrix(T_, false));
+ } catch (final NotConvergedException e) {
+ // keep the original exception as the cause so its stack trace is not lost
+ throw new RuntimeException(e.getMessage(), e);
+ }
+
+ // TODO: can we only use k_ values?
+ // EVD gives values low-to-high; SVD does high-to-low and we want that order. Reverse
+ // the list when extracting SVs from eigenvalues, and generate a diagonal rotation matrix
+ // to save on an extra matrix multiply if we need to compute vectors later.
+ final double[] ev = evd_.getEigenvalues();
+ for (int i = 0; i < ev.length; ++i) {
+ // AA^T is positive semi-definite, but round-off can produce a tiny negative eigenvalue;
+ // clamp at zero so the square root cannot turn into NaN
+ final double val = Math.sqrt(Math.max(ev[i], 0.0));
+ sv_[n_ - i - 1] = val;
+ if (val > 0) {
+ rotS_.set(n_ - i - 1, i, 1 / val);
+ }
+ }
+
+ if (computeVectors) {
+ rotS_.transBmult(evd_.getEigenvectors(), T_);
+ T_.mult(A, Vt_);
+ }
+ }
+
+}
diff --git a/src/main/java/com/yahoo/sketches/vector/decomposition/MatrixOpsImplOjAlgo.java b/src/main/java/com/yahoo/sketches/vector/decomposition/MatrixOpsImplOjAlgo.java
new file mode 100644
index 0000000..b5b522f
--- /dev/null
+++ b/src/main/java/com/yahoo/sketches/vector/decomposition/MatrixOpsImplOjAlgo.java
@@ -0,0 +1,213 @@
+package com.yahoo.sketches.vector.decomposition;
+
+import java.util.Optional;
+
+import org.ojalgo.matrix.decomposition.Eigenvalue;
+import org.ojalgo.matrix.decomposition.QR;
+import org.ojalgo.matrix.decomposition.SingularValue;
+import org.ojalgo.matrix.store.MatrixStore;
+import org.ojalgo.matrix.store.PrimitiveDenseStore;
+import org.ojalgo.matrix.store.SparseStore;
+import org.ojalgo.random.Normal;
+
+import com.yahoo.sketches.vector.matrix.Matrix;
+import com.yahoo.sketches.vector.matrix.MatrixImplOjAlgo;
+import com.yahoo.sketches.vector.matrix.MatrixType;
+
+class MatrixOpsImplOjAlgo extends MatrixOps {
+ private double[] sv_;
+ private PrimitiveDenseStore Vt_;
+
+ // work objects for SISVD
+ private PrimitiveDenseStore block_;
+ private PrimitiveDenseStore T_; // also used in SymmetricEVD
+ private QR<Double> qr_;
+
+ // work objects for Symmetric EVD
+ private Eigenvalue<Double> evd_;
+
+
+ transient private SparseStore<Double> S_; // to hold singular value matrix
+
+ MatrixOpsImplOjAlgo(final int n, final int d, final SVDAlgo algo, final int k) {
+ super(n, d, algo, k);
+
+ // Allocate space for the decomposition
+ sv_ = new double[Math.min(n_, d_)];
+ Vt_ = null; // lazy allocation
+ }
+
+ @Override
+ void svd(final Matrix A, final boolean computeVectors) {
+ assert A.getMatrixType() == MatrixType.OJALGO;
+
+ if (A.getNumRows() != n_) {
+ throw new IllegalArgumentException("A.numRows() != n_");
+ } else if (A.getNumColumns() != d_) {
+ throw new IllegalArgumentException("A.numColumns() != d_");
+ }
+
+ if (computeVectors && Vt_ == null) {
+ Vt_ = PrimitiveDenseStore.FACTORY.makeZero(n_, d_);
+ S_ = SparseStore.makePrimitive(sv_.length, sv_.length);
+ }
+
+ switch (algo_) {
+ case FULL:
+ computeFullSVD((PrimitiveDenseStore) A.getRawObject(), computeVectors);
+ return;
+
+ case SISVD:
+ computeSISVD((PrimitiveDenseStore) A.getRawObject(), computeVectors);
+ return;
+
+ case SYM:
+ computeSymmEigSVD((PrimitiveDenseStore) A.getRawObject(), computeVectors);
+ return;
+
+ default:
+ throw new RuntimeException("SVDAlgo type not (yet?) supported: " + algo_.toString());
+ }
+ }
+
+ @Override
+ double[] getSingularValues() {
+ return sv_;
+ }
+
+ @Override
+ Matrix getVt() {
+ return MatrixImplOjAlgo.wrap(Vt_);
+ }
+
+ @Override
+ double reduceRank(final Matrix A) {
+ svd(A, true);
+
+ double svAdjustment = 0.0;
+
+ if (sv_.length >= k_) {
+ double medianSVSq = sv_[k_ - 1]; // (l_/2)th item, not yet squared
+ medianSVSq *= medianSVSq;
+ svAdjustment += medianSVSq; // always track, even if not using compensative mode
+ for (int i = 0; i < k_ - 1; ++i) {
+ final double val = sv_[i];
+ final double adjSqSV = val * val - medianSVSq;
+ S_.set(i, i, adjSqSV < 0 ? 0.0 : Math.sqrt(adjSqSV)); // just to be safe
+ }
+ for (int i = k_ - 1; i < S_.countColumns(); ++i) {
+ S_.set(i, i, 0.0);
+ }
+ } else {
+ throw new RuntimeException("Running with d < 2k not (yet?) supported");
+ /*
+ for (int i = 0; i < sv_.length; ++i) {
+ S_.set(i, i, sv_[i]);
+ }
+ for (int i = sv_.length; i < S_.countColumns(); ++i) {
+ S_.set(i, i, 0.0);
+ }
+ */
+ }
+
+ // store the result back in A
+ S_.multiply(Vt_).supplyTo((PrimitiveDenseStore) A.getRawObject());
+
+ return svAdjustment;
+ }
+
+ @Override
+ Matrix applyAdjustment(final Matrix A, final double svAdjustment) {
+ // copy A before decomposing
+ final PrimitiveDenseStore result
+ = PrimitiveDenseStore.FACTORY.copy((PrimitiveDenseStore) A.getRawObject());
+ svd(Matrix.wrap(result), true);
+
+ for (int i = 0; i < k_ - 1; ++i) {
+ final double val = sv_[i];
+ final double adjSV = Math.sqrt(val * val + svAdjustment);
+ S_.set(i, i, adjSV);
+ }
+ for (int i = k_ - 1; i < S_.countColumns(); ++i) {
+ S_.set(i, i, 0.0);
+ }
+
+ S_.multiply(Vt_).supplyTo(result);
+
+ return Matrix.wrap(result);
+ }
+
+ private void computeFullSVD(final MatrixStore<Double> A, final boolean computeVectors) {
+ final SingularValue<Double> svd = SingularValue.make(A);
+ svd.compute(A);
+
+ svd.getSingularValues(sv_);
+
+ if (computeVectors) {
+ svd.getQ2().transpose().supplyTo(Vt_);
+ }
+ }
+
+ private void computeSISVD(final MatrixStore<Double> A, final boolean computeVectors) {
+ // want to iterate on smaller dimension of A (n x d)
+ // currently, error in constructor if d < n, so n is always the smaller dimension
+ if (block_ == null) {
+ block_ = PrimitiveDenseStore.FACTORY.makeFilled(d_, k_, new Normal(0.0, 1.0));
+ qr_ = QR.PRIMITIVE.make(block_);
+ T_ = PrimitiveDenseStore.FACTORY.makeZero(n_, k_);
+ } else {
+ block_.fillAll(new Normal(0.0, 1.0));
+ }
+
+ // orthogonalize for numeric stability
+ qr_.decompose(block_);
+ qr_.getQ().supplyTo(block_);
+
+ for (int i = 0; i < numSISVDIter_; ++i) {
+ A.multiply(block_).supplyTo(T_);
+ A.transpose().multiply(T_).supplyTo(block_);
+
+ // again, just for stability
+ qr_.decompose(block_);
+ qr_.getQ().supplyTo(block_);
+ }
+
+ // Rayleigh-Ritz postprocessing
+ A.multiply(block_).supplyTo(T_);
+
+ final SingularValue<Double> svd = SingularValue.make(T_);
+ svd.compute(T_);
+
+ svd.getSingularValues(sv_);
+
+ if (computeVectors) {
+ // V = block * Q2^T so V^T = Q2 * block^T
+ // and ojAlgo figures out that it only needs to fill the first k_ rows of Vt_
+ svd.getQ2().multiply(block_.transpose()).supplyTo(Vt_);
+ }
+ }
+
+ /**
+ * Derives the SVD of A from an eigendecomposition of the n_ x n_ matrix A * A^T. Singular values are
+ * the square roots of the eigenvalues, stored in sv_; when requested, Vt_ is computed as
+ * S^-1 * V^T * A, with S_ temporarily holding the inverse singular values.
+ *
+ * @param A Matrix to decompose
+ * @param computeVectors True to also fill Vt_ (S_ is only touched in this case)
+ */
+ private void computeSymmEigSVD(final MatrixStore<Double> A, final boolean computeVectors) {
+ if (T_ == null) {
+ T_ = PrimitiveDenseStore.FACTORY.makeZero(n_, n_);
+ evd_ = Eigenvalue.PRIMITIVE.make(n_, true);
+ }
+
+ // want left singular vectors U, aka eigenvectors of AA^T -- so compute that
+ A.multiply(A.transpose()).supplyTo(T_);
+ evd_.decompose(T_);
+
+ // TODO: can we only use k_ values?
+ final double[] ev = new double[n_];
+ evd_.getEigenvalues(ev, Optional.empty());
+ for (int i = 0; i < ev.length; ++i) {
+ // AA^T is positive semi-definite, but round-off can produce a tiny negative eigenvalue;
+ // clamp at zero so the square root cannot turn into NaN
+ final double val = Math.sqrt(Math.max(ev[i], 0.0));
+ sv_[i] = val;
+ // S_ is shared with reduceRank() and applyAdjustment(), so always overwrite the entry;
+ // leaving it untouched for a zero singular value would reuse a value from an earlier call
+ if (computeVectors) { S_.set(i, i, val > 0 ? 1 / val : 0.0); }
+ }
+
+ if (computeVectors) {
+ S_.multiply(evd_.getV().transpose()).multiply(A).supplyTo(Vt_);
+ }
+ }
+}
diff --git a/src/main/java/com/yahoo/sketches/vector/decomposition/PreambleUtil.java b/src/main/java/com/yahoo/sketches/vector/decomposition/PreambleUtil.java
index f177067..0de11d3 100644
--- a/src/main/java/com/yahoo/sketches/vector/decomposition/PreambleUtil.java
+++ b/src/main/java/com/yahoo/sketches/vector/decomposition/PreambleUtil.java
@@ -45,7 +45,7 @@
/**
* The java line separator character as a String.
*/
- public static final String LS = System.getProperty("line.separator");
+ private static final String LS = System.getProperty("line.separator");
private PreambleUtil() {}
@@ -63,7 +63,6 @@
// flag bit masks
static final int EMPTY_FLAG_MASK = 4;
- static final int COMPENSATIVE_FLAG_MASK = 128;
// Other constants
static final int SER_VER = 1;
diff --git a/src/main/java/com/yahoo/sketches/vector/decomposition/SVDAlgo.java b/src/main/java/com/yahoo/sketches/vector/decomposition/SVDAlgo.java
new file mode 100644
index 0000000..36b5462
--- /dev/null
+++ b/src/main/java/com/yahoo/sketches/vector/decomposition/SVDAlgo.java
@@ -0,0 +1,50 @@
+package com.yahoo.sketches.vector.decomposition;
+
+/**
+ * This enum allows a choice of algorithms for Singular Value Decomposition. The options are:
+ * <ul>
+ * <li>FULL: The matrix library's default SVD implementation.</li>
+ * <li>SISVD: Simultaneous iteration, an approximate method likely to be more efficient only with sparse
+ * matrices or when <em>k</em> is significantly smaller than the number of rows in the sketch.</li>
+ * <li>SYM: Takes advantage of matrix dimensionality, first computing eigenvalues of AA^T, then computes
+ * intended results. Squaring A alters condition number and may cause numeric stability issues,
+ * but unlikely an issue for Frequent Directions since discarding the smaller singular values/vectors.</li>
+ * </ul>
+ */
+public enum SVDAlgo {
+  /** The matrix library's default SVD implementation. */
+  FULL(1, "Full"),
+
+  /** Simultaneous iteration, an approximate method. */
+  SISVD(2, "SISVD"),
+
+  /** Eigendecomposition of AA^T, from which the SVD results are derived. */
+  SYM(3, "Symmetrized");
+
+  private final int id_;
+  private final String name_;
+
+  SVDAlgo(final int id, final String name) {
+    id_ = id;
+    name_ = name;
+  }
+
+  /**
+   * Returns the numeric identifier of this algorithm.
+   * @return the id assigned to this constant
+   */
+  public int getId() { return id_; }
+
+  /**
+   * Returns the display name of this algorithm.
+   * @return the name assigned to this constant
+   */
+  public String getName() { return name_; }
+
+  /**
+   * Returns the display name, the same value as {@link #getName()}.
+   * @return the name assigned to this constant
+   */
+  @Override
+  public String toString() { return name_; }
+}
diff --git a/src/main/java/com/yahoo/sketches/vector/matrix/Matrix.java b/src/main/java/com/yahoo/sketches/vector/matrix/Matrix.java
index 533e420..d6b6de2 100644
--- a/src/main/java/com/yahoo/sketches/vector/matrix/Matrix.java
+++ b/src/main/java/com/yahoo/sketches/vector/matrix/Matrix.java
@@ -12,7 +12,7 @@
import com.yahoo.memory.Memory;
import com.yahoo.sketches.vector.MatrixFamily;
-
+import no.uib.cipr.matrix.DenseMatrix;
/**
* Provides an implementation-agnostic wrapper around Matrix classes.
@@ -32,10 +32,12 @@
* @param type Matrix implementation type to use
* @return The heapified matrix
*/
- public static Matrix heapify(final Memory srcMem, final MatrixBuilder.Algo type) {
+ public static Matrix heapify(final Memory srcMem, final MatrixType type) {
switch (type) {
case OJALGO:
return MatrixImplOjAlgo.heapifyInstance(srcMem);
+ case MTJ:
+ return MatrixImplMTJ.heapifyInstance(srcMem);
default:
return null;
}
@@ -52,6 +54,8 @@
return null;
} else if (mtx instanceof PrimitiveDenseStore) {
return MatrixImplOjAlgo.wrap((PrimitiveDenseStore) mtx);
+ } else if (mtx instanceof DenseMatrix) {
+ return MatrixImplMTJ.wrap((DenseMatrix) mtx);
}
else {
throw new IllegalArgumentException("wrap() does not currently support "
@@ -156,7 +160,7 @@
}
/**
- * Gets serialized size of the Matrix in cmpact form, in bytes.
+ * Gets serialized size of the Matrix in compact form, in bytes.
* @param rows Number of rows to select for writing
* @param cols Number of columns to select for writing
* @return Number of bytes needed to serialize the first (rows, cols) of this Matrix
@@ -213,4 +217,6 @@
return sb.toString();
}
+
+ public abstract MatrixType getMatrixType();
}
diff --git a/src/main/java/com/yahoo/sketches/vector/matrix/MatrixBuilder.java b/src/main/java/com/yahoo/sketches/vector/matrix/MatrixBuilder.java
index ba30844..73f0751 100644
--- a/src/main/java/com/yahoo/sketches/vector/matrix/MatrixBuilder.java
+++ b/src/main/java/com/yahoo/sketches/vector/matrix/MatrixBuilder.java
@@ -10,27 +10,8 @@
* Provides a builder for Matrix objects.
*/
public class MatrixBuilder {
- public enum Algo {
- OJALGO(1, "ojAlgo"),
- NATIVE(2, "native");
- private int id_;
- private String name_;
-
- Algo(final int id, final String name) {
- id_ = id;
- name_ = name;
- }
-
- public int getId() { return id_; }
-
- public String getName() { return name_; }
-
- @Override
- public String toString() { return name_; }
- }
-
- private Algo type_ = Algo.OJALGO; // default type
+ private MatrixType type_ = MatrixType.OJALGO; // default type
public MatrixBuilder() {}
@@ -39,16 +20,16 @@
* @param type One of the supported types
* @return This MatrixBuilder object
*/
- public MatrixBuilder setType(final Algo type) {
+ public MatrixBuilder setType(final MatrixType type) {
type_ = type;
return this;
}
/**
* Returns a value from an enum defining the type of object backing any Matrix objects created.
- * @return An item from the Algo enum.
+ * @return An item from the MatrixType enum.
*/
- public Algo getBackingType() {
+ public MatrixType getBackingType() {
return type_;
}
@@ -64,9 +45,11 @@
case OJALGO:
return MatrixImplOjAlgo.newInstance(numRows, numCols);
- case NATIVE:
+ case MTJ:
+ return MatrixImplMTJ.newInstance(numRows, numCols);
+
default:
- throw new IllegalArgumentException("Only Algo.OJALGO is currently supported Matrix type");
+ throw new IllegalArgumentException("OJALGO and MTJ are currently the only supported MatrixTypes");
}
}
}
diff --git a/src/main/java/com/yahoo/sketches/vector/matrix/MatrixImplMTJ.java b/src/main/java/com/yahoo/sketches/vector/matrix/MatrixImplMTJ.java
new file mode 100644
index 0000000..5834b1a
--- /dev/null
+++ b/src/main/java/com/yahoo/sketches/vector/matrix/MatrixImplMTJ.java
@@ -0,0 +1,297 @@
+/*
+ * Copyright 2017, Yahoo! Inc.
+ * Licensed under the terms of the Apache License 2.0. See LICENSE file at the project root
+ * for terms.
+ */
+
+package com.yahoo.sketches.vector.matrix;
+
+import static com.yahoo.sketches.vector.matrix.MatrixPreambleUtil.COMPACT_FLAG_MASK;
+import static com.yahoo.sketches.vector.matrix.MatrixPreambleUtil.extractFamilyID;
+import static com.yahoo.sketches.vector.matrix.MatrixPreambleUtil.extractFlags;
+import static com.yahoo.sketches.vector.matrix.MatrixPreambleUtil.extractNumColumns;
+import static com.yahoo.sketches.vector.matrix.MatrixPreambleUtil.extractNumColumnsUsed;
+import static com.yahoo.sketches.vector.matrix.MatrixPreambleUtil.extractNumRows;
+import static com.yahoo.sketches.vector.matrix.MatrixPreambleUtil.extractNumRowsUsed;
+import static com.yahoo.sketches.vector.matrix.MatrixPreambleUtil.extractPreLongs;
+import static com.yahoo.sketches.vector.matrix.MatrixPreambleUtil.extractSerVer;
+
+import com.yahoo.memory.Memory;
+import com.yahoo.memory.WritableMemory;
+import com.yahoo.sketches.vector.MatrixFamily;
+import no.uib.cipr.matrix.DenseMatrix;
+
+/**
+ * A {@link Matrix} implementation backed by an MTJ {@link DenseMatrix}.
+ */
+public final class MatrixImplMTJ extends Matrix {
+  private final DenseMatrix mtx_;
+
+  private MatrixImplMTJ(final int numRows, final int numCols) {
+    mtx_ = new DenseMatrix(numRows, numCols);
+    numRows_ = numRows;
+    numCols_ = numCols;
+  }
+
+  private MatrixImplMTJ(final DenseMatrix mtx) {
+    mtx_ = mtx;
+    numRows_ = mtx.numRows();
+    numCols_ = mtx.numColumns();
+  }
+
+  /**
+   * Creates a new matrix of the given size.
+   * @param numRows Number of rows
+   * @param numCols Number of columns
+   * @return A new MTJ-backed Matrix
+   */
+  static Matrix newInstance(final int numRows, final int numCols) {
+    return new MatrixImplMTJ(numRows, numCols);
+  }
+
+  /**
+   * Reads a serialized Matrix image, in either standard or compact form, into a new instance.
+   * @param srcMem Memory holding the serialized image
+   * @return The heapified matrix
+   * @throws IllegalArgumentException if srcMem is too small for its preamble or declared data,
+   *     has an unexpected serial version or family ID, or declares inconsistent dimensions
+   */
+  static Matrix heapifyInstance(final Memory srcMem) {
+    final int minBytes = MatrixFamily.MATRIX.getMinPreLongs() * Long.BYTES;
+    final long memCapBytes = srcMem.getCapacity();
+    if (memCapBytes < minBytes) {
+      throw new IllegalArgumentException("Source Memory too small: " + memCapBytes
+          + " < " + minBytes);
+    }
+
+    final int preLongs = extractPreLongs(srcMem);
+    final int serVer = extractSerVer(srcMem);
+    final int familyID = extractFamilyID(srcMem);
+
+    if (serVer != 1) {
+      throw new IllegalArgumentException("Invalid SerVer reading srcMem. Expected 1, found: "
+          + serVer);
+    }
+    if (familyID != MatrixFamily.MATRIX.getID()) {
+      throw new IllegalArgumentException("srcMem does not point to a Matrix");
+    }
+
+    final int flags = extractFlags(srcMem);
+    final boolean isCompact = (flags & COMPACT_FLAG_MASK) > 0;
+
+    final int nRows = extractNumRows(srcMem);
+    final int nCols = extractNumColumns(srcMem);
+
+    // widen before multiplying: the product of the declared dimensions may not fit in int
+    final long declaredElements = (long) nRows * nCols;
+    if ((nRows < 0) || (nCols < 0) || (declaredElements > Integer.MAX_VALUE)) {
+      throw new IllegalArgumentException("Invalid matrix dimensions: (" + nRows + ", "
+          + nCols + ")");
+    }
+
+    final long dataOffset = (long) preLongs * Long.BYTES;
+
+    if (isCompact) {
+      final int usedRows = extractNumRowsUsed(srcMem);
+      final int usedCols = extractNumColumnsUsed(srcMem);
+      if ((usedRows < 0) || (usedRows > nRows) || (usedCols < 0) || (usedCols > nCols)) {
+        throw new IllegalArgumentException("Used size (" + usedRows + ", " + usedCols
+            + ") is outside matrix bounds (" + nRows + ", " + nCols + ")");
+      }
+      checkDataCapacity(memCapBytes, dataOffset, (long) usedRows * usedCols);
+
+      // only the used sub-matrix was written, column by column; other entries are left untouched
+      final MatrixImplMTJ matrix = new MatrixImplMTJ(nRows, nCols);
+      long memOffset = dataOffset;
+      for (int c = 0; c < usedCols; ++c) {
+        for (int r = 0; r < usedRows; ++r) {
+          matrix.mtx_.set(r, c, srcMem.getDouble(memOffset));
+          memOffset += Double.BYTES;
+        }
+      }
+      return matrix;
+    }
+
+    checkDataCapacity(memCapBytes, dataOffset, declaredElements);
+    final double[] data = new double[(int) declaredElements];
+    srcMem.getDoubleArray(dataOffset, data, 0, data.length);
+    // deep = false: per the MTJ constructor contract the matrix adopts data rather than copying it
+    return new MatrixImplMTJ(new DenseMatrix(nRows, nCols, data, false));
+  }
+
+  /**
+   * Wraps an existing DenseMatrix; the returned Matrix keeps a reference to mtx.
+   * @param mtx The DenseMatrix to wrap
+   * @return A Matrix backed by mtx
+   */
+  static Matrix wrap(final DenseMatrix mtx) {
+    return new MatrixImplMTJ(mtx);
+  }
+
+  @Override
+  public Object getRawObject() {
+    return mtx_;
+  }
+
+  /**
+   * Serializes the whole matrix in standard (non-compact) form: a 2-long preamble followed by
+   * the contents of the backing data array.
+   * @return The serialized image
+   * @throws IllegalStateException if the image would not fit in a byte[]
+   */
+  @Override
+  public byte[] toByteArray() {
+    final int preLongs = 2;
+    // widen before multiplying so the element count cannot wrap in int arithmetic
+    final long numElements = (long) numRows_ * numCols_;
+    assert numElements == ((long) mtx_.numRows() * mtx_.numColumns());
+
+    final byte[] outByteArr = new byte[serializedSizeBytes(preLongs, numElements)];
+    final WritableMemory memOut = WritableMemory.wrap(outByteArr);
+    final Object memObj = memOut.getArray();
+    final long memAddr = memOut.getCumulativeOffset(0L);
+
+    MatrixPreambleUtil.insertPreLongs(memObj, memAddr, preLongs);
+    MatrixPreambleUtil.insertSerVer(memObj, memAddr, MatrixPreambleUtil.SER_VER);
+    MatrixPreambleUtil.insertFamilyID(memObj, memAddr, MatrixFamily.MATRIX.getID());
+    MatrixPreambleUtil.insertFlags(memObj, memAddr, 0);
+    MatrixPreambleUtil.insertNumRows(memObj, memAddr, numRows_);
+    MatrixPreambleUtil.insertNumColumns(memObj, memAddr, numCols_);
+    memOut.putDoubleArray(preLongs * Long.BYTES, mtx_.getData(), 0, (int) numElements);
+
+    return outByteArr;
+  }
+
+  /**
+   * Serializes the leading (numRows, numCols) sub-matrix in compact form. When the request
+   * covers the whole matrix the standard form from {@link #toByteArray()} is returned instead.
+   * @param numRows Number of rows to write
+   * @param numCols Number of columns to write
+   * @return The serialized image
+   * @throws IllegalArgumentException if numRows or numCols is negative or exceeds the matrix size
+   * @throws IllegalStateException if the image would not fit in a byte[]
+   */
+  @Override
+  public byte[] toCompactByteArray(final int numRows, final int numCols) {
+    final int totalRows = mtx_.numRows();
+    final int totalCols = mtx_.numColumns();
+    if ((numRows < 0) || (numRows > totalRows) || (numCols < 0) || (numCols > totalCols)) {
+      throw new IllegalArgumentException("Requested size (" + numRows + ", " + numCols
+          + ") is outside matrix bounds (" + totalRows + ", " + totalCols + ")");
+    }
+
+    final int preLongs = 3;
+
+    // for non-compact we can do an array copy, so save as non-compact if using the entire matrix
+    final long numElements = (long) numRows * numCols;
+    final boolean isCompact = numElements < ((long) totalRows * totalCols);
+    if (!isCompact) {
+      return toByteArray();
+    }
+
+    final byte[] outByteArr = new byte[serializedSizeBytes(preLongs, numElements)];
+    final WritableMemory memOut = WritableMemory.wrap(outByteArr);
+    final Object memObj = memOut.getArray();
+    final long memAddr = memOut.getCumulativeOffset(0L);
+
+    MatrixPreambleUtil.insertPreLongs(memObj, memAddr, preLongs);
+    MatrixPreambleUtil.insertSerVer(memObj, memAddr, MatrixPreambleUtil.SER_VER);
+    MatrixPreambleUtil.insertFamilyID(memObj, memAddr, MatrixFamily.MATRIX.getID());
+    MatrixPreambleUtil.insertFlags(memObj, memAddr, COMPACT_FLAG_MASK);
+    MatrixPreambleUtil.insertNumRows(memObj, memAddr, totalRows);
+    MatrixPreambleUtil.insertNumColumns(memObj, memAddr, totalCols);
+    MatrixPreambleUtil.insertNumRowsUsed(memObj, memAddr, numRows);
+    MatrixPreambleUtil.insertNumColumnsUsed(memObj, memAddr, numCols);
+
+    // write elements in column-major order
+    long offsetBytes = preLongs * Long.BYTES;
+    for (int c = 0; c < numCols; ++c) {
+      for (int r = 0; r < numRows; ++r) {
+        memOut.putDouble(offsetBytes, mtx_.get(r, c));
+        offsetBytes += Double.BYTES;
+      }
+    }
+
+    return outByteArr;
+  }
+
+  @Override
+  public double getElement(final int row, final int col) {
+    return mtx_.get(row, col);
+  }
+
+  @Override
+  public double[] getRow(final int row) {
+    final int cols = mtx_.numColumns();
+    final double[] result = new double[cols];
+    for (int c = 0; c < cols; ++c) {
+      result[c] = mtx_.get(row, c);
+    }
+    return result;
+  }
+
+  @Override
+  public double[] getColumn(final int col) {
+    final int rows = mtx_.numRows();
+    final double[] result = new double[rows];
+    for (int r = 0; r < rows; ++r) {
+      result[r] = mtx_.get(r, col);
+    }
+    return result;
+  }
+
+  @Override
+  public void setElement(final int row, final int col, final double value) {
+    mtx_.set(row, col, value);
+  }
+
+  @Override
+  public void setRow(final int row, final double[] values) {
+    if (values.length != mtx_.numColumns()) {
+      throw new IllegalArgumentException("Invalid number of elements for row. Expected "
+          + mtx_.numColumns() + ", found " + values.length);
+    }
+
+    for (int i = 0; i < mtx_.numColumns(); ++i) {
+      mtx_.set(row, i, values[i]);
+    }
+  }
+
+  @Override
+  public void setColumn(final int column, final double[] values) {
+    if (values.length != mtx_.numRows()) {
+      throw new IllegalArgumentException("Invalid number of elements for column. Expected "
+          + mtx_.numRows() + ", found " + values.length);
+    }
+
+    for (int i = 0; i < mtx_.numRows(); ++i) {
+      // TODO: System.arraycopy()?
+      mtx_.set(i, column, values[i]);
+    }
+  }
+
+  @Override
+  public MatrixType getMatrixType() {
+    return MatrixType.MTJ;
+  }
+
+  // Verifies that numElements doubles starting at dataOffset fit within capBytes
+  private static void checkDataCapacity(final long capBytes, final long dataOffset,
+                                        final long numElements) {
+    final long required = dataOffset + (numElements * Double.BYTES);
+    if (capBytes < required) {
+      throw new IllegalArgumentException("Source Memory too small: " + capBytes
+          + " < " + required);
+    }
+  }
+
+  // Returns the byte[] length for preLongs longs plus numElements doubles; rejects sizes beyond int
+  private static int serializedSizeBytes(final int preLongs, final long numElements) {
+    final long preBytes = (long) preLongs * Long.BYTES;
+    if (numElements > ((Integer.MAX_VALUE - preBytes) / Double.BYTES)) {
+      throw new IllegalStateException("Matrix too large to serialize: " + numElements
+          + " elements");
+    }
+    return (int) (preBytes + (numElements * Double.BYTES));
+  }
+}
diff --git a/src/main/java/com/yahoo/sketches/vector/matrix/MatrixImplOjAlgo.java b/src/main/java/com/yahoo/sketches/vector/matrix/MatrixImplOjAlgo.java
index 6e4c7b9..0af7f5c 100644
--- a/src/main/java/com/yahoo/sketches/vector/matrix/MatrixImplOjAlgo.java
+++ b/src/main/java/com/yahoo/sketches/vector/matrix/MatrixImplOjAlgo.java
@@ -211,4 +211,10 @@
mtx_.set(i, column, values[i]);
}
}
+
+ @Override
+ public MatrixType getMatrixType() {
+ return MatrixType.OJALGO;
+ }
+
}
diff --git a/src/main/java/com/yahoo/sketches/vector/matrix/MatrixPreambleUtil.java b/src/main/java/com/yahoo/sketches/vector/matrix/MatrixPreambleUtil.java
index 1bfc907..8110b8f 100644
--- a/src/main/java/com/yahoo/sketches/vector/matrix/MatrixPreambleUtil.java
+++ b/src/main/java/com/yahoo/sketches/vector/matrix/MatrixPreambleUtil.java
@@ -43,7 +43,7 @@
/**
* The java line separator character as a String.
*/
- public static final String LS = System.getProperty("line.separator");
+ static final String LS = System.getProperty("line.separator");
private MatrixPreambleUtil() {}
@@ -196,7 +196,7 @@
* @param mem the given Memory
* @return the extracted prelongs value.
*/
- static int getAndCheckPreLongs(final Memory mem) {
+ private static int getAndCheckPreLongs(final Memory mem) {
final long cap = mem.getCapacity();
if (cap < Long.BYTES) { throwNotBigEnough(cap, Long.BYTES); }
final int preLongs = extractPreLongs(mem);
diff --git a/src/main/java/com/yahoo/sketches/vector/matrix/MatrixType.java b/src/main/java/com/yahoo/sketches/vector/matrix/MatrixType.java
new file mode 100644
index 0000000..8a3dac2
--- /dev/null
+++ b/src/main/java/com/yahoo/sketches/vector/matrix/MatrixType.java
@@ -0,0 +1,39 @@
+package com.yahoo.sketches.vector.matrix;
+
+/**
+ * Identifies which library backs a {@link Matrix} instance.
+ */
+public enum MatrixType {
+  /** Backed by ojAlgo. */
+  OJALGO(1, "ojAlgo"),
+
+  /** Backed by Matrix Toolkits Java (MTJ). */
+  MTJ(2, "MTJ");
+
+  private final int id_;
+  private final String name_;
+
+  MatrixType(final int id, final String name) {
+    id_ = id;
+    name_ = name;
+  }
+
+  /**
+   * Returns the numeric identifier of this type.
+   * @return the id assigned to this constant
+   */
+  public int getId() { return id_; }
+
+  /**
+   * Returns the display name of this type.
+   * @return the name assigned to this constant
+   */
+  public String getName() { return name_; }
+
+  /**
+   * Returns the display name, the same value as {@link #getName()}.
+   * @return the name assigned to this constant
+   */
+  @Override
+  public String toString() { return name_; }
+}
diff --git a/src/test/java/com/yahoo/sketches/vector/decomposition/FrequentDirectionsTest.java b/src/test/java/com/yahoo/sketches/vector/decomposition/FrequentDirectionsTest.java
index 65e1244..d9485b3 100644
--- a/src/test/java/com/yahoo/sketches/vector/decomposition/FrequentDirectionsTest.java
+++ b/src/test/java/com/yahoo/sketches/vector/decomposition/FrequentDirectionsTest.java
@@ -61,10 +61,28 @@
}
@Test
- public void checkUpdate() {
+ public void checkSymmUpdate() {
final int k = 4;
final int d = 16; // should be > 2k
final FrequentDirections fd = FrequentDirections.newInstance(k, d);
+ fd.setSVDAlgo(SVDAlgo.SYM); // default, but force anyway
+
+ runUpdateTest(fd);
+ }
+
+ @Test
+ public void checkFullSVDUpdate() {
+ final int k = 4;
+ final int d = 16; // should be > 2k
+ final FrequentDirections fd = FrequentDirections.newInstance(k, d);
+ fd.setSVDAlgo(SVDAlgo.FULL);
+
+ runUpdateTest(fd);
+ }
+
+ private void runUpdateTest(final FrequentDirections fd) {
+ final int k = fd.getK();
+ final int d = fd.getD();
// creates matrix with increasing values along diagonal
final double[] input = new double[d];
@@ -82,17 +100,10 @@
input[(2 * k) - 1] = 0.0;
input[2 * k] = 2.0 * k;
fd.update(input); // trigger reduceRank(), then add 1 more row
- assertEquals(fd.getNumRows(), k + 1);
-
- fd.reset();
- assertTrue(fd.isEmpty());
- fd.forceReduceRank(); // should be a no-op
- assertTrue(fd.isEmpty());
-
- println(fd.toString());
- println(fd.toString(true));
+ assertEquals(fd.getNumRows(), k);
}
+
@Test
public void updateWithTooFewDimensions() {
final int k = 4;
@@ -140,22 +151,40 @@
assertEquals(fd2.getN(), initialRows);
fd1.update(fd2);
- final int expectedRows = ((2 * initialRows) % k) + k; // assumes 2 * initialRows > k
+ final int expectedRows = ((2 * initialRows) % k) + k - 1; // assumes 2 * initialRows > k
assertEquals(fd1.getNumRows(), expectedRows);
assertEquals(fd1.getN(), 2 * initialRows);
final Matrix result = fd1.getResult(false);
assertNotNull(result);
- assertEquals(result.getNumRows(), expectedRows);
+ assertEquals(result.getNumRows(), 2 * k);
println(fd1.toString(true, true, true));
}
@Test
- public void checkCompensativeResult() {
+ public void checkCompensativeResultSymSVD() {
final int k = 4;
final int d = 10; // should be > 2k
final FrequentDirections fd = FrequentDirections.newInstance(k, d);
+ fd.setSVDAlgo(SVDAlgo.SYM);
+
+ runCompensativeResultTest(fd);
+ }
+
+ @Test
+ public void checkCompensativeResultFullSVD() {
+ final int k = 4;
+ final int d = 10; // should be > 2k
+ final FrequentDirections fd = FrequentDirections.newInstance(k, d);
+ fd.setSVDAlgo(SVDAlgo.FULL);
+
+ runCompensativeResultTest(fd);
+ }
+
+ private void runCompensativeResultTest(final FrequentDirections fd) {
+ final int d = fd.getD();
+ final int k = fd.getK();
// diagonal matrix for easy checking
final double[] input = new double[d];
@@ -172,12 +201,11 @@
assertEquals(m.getElement(i,i), 1.0 * (i + 1), 1e-6);
}
- final Matrix p = fd.getProjectionMatrix();
- double[] sv = fd.getSingularValues(false);
-
// without compensation, but force rank reduction and check projection at the same time
fd.forceReduceRank();
m = fd.getResult();
+ final Matrix p = fd.getProjectionMatrix();
+ double[] sv = fd.getSingularValues(false);
for (int i = k; i > 1; --i) {
final double val = Math.abs(m.getElement(k - i, i));
final double expected = Math.sqrt(((i + 1) * (i + 1)) - fd.getSvAdjustment());
@@ -185,8 +213,8 @@
assertEquals(sv[k - i], expected, 1e-10);
assertEquals(Math.abs(p.getElement(k - i, i)), 1.0, 1e-6);
}
- assertEquals(m.getElement(k, 1), 0.0);
- assertEquals(p.getElement(k, 1), 0.0);
+ assertEquals(m.getElement(k, 1), 0.0, 0.0); // might return -0.0
+ assertEquals(p.getElement(k, 1), 0.0, 0.0); // might return -0.0
// with compensation
m = fd.getResult(true);
@@ -332,10 +360,10 @@
}
}
-/**
- * println the message
- * @param msg the message
- */
+ /**
+ * println the message
+ * @param msg the message
+ */
private void println(final String msg) {
//System.out.println(msg);
}
diff --git a/src/test/java/com/yahoo/sketches/vector/decomposition/MatrixOpsTest.java b/src/test/java/com/yahoo/sketches/vector/decomposition/MatrixOpsTest.java
new file mode 100644
index 0000000..b8ea789
--- /dev/null
+++ b/src/test/java/com/yahoo/sketches/vector/decomposition/MatrixOpsTest.java
@@ -0,0 +1,107 @@
+package com.yahoo.sketches.vector.decomposition;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.fail;
+
+import org.testng.annotations.Test;
+
+import com.yahoo.sketches.vector.matrix.Matrix;
+import com.yahoo.sketches.vector.matrix.MatrixBuilder;
+import com.yahoo.sketches.vector.matrix.MatrixType;
+
+public class MatrixOpsTest {
+
+ @Test
+ public void compareSVDAccuracy() {
+ final int d = 10;
+ final int k = 6;
+ final Matrix input = generateIncreasingEye(d, 2 * k);
+
+ final MatrixOps moFull = MatrixOps.newInstance(input, SVDAlgo.FULL, k);
+ final MatrixOps moSym = MatrixOps.newInstance(input, SVDAlgo.SYM, k);
+ final MatrixOps moSISVD = MatrixOps.newInstance(input, SVDAlgo.SISVD, k);
+ moSISVD.setNumSISVDIter(50 * k); // intentionally run many extra iterations for tighter convergence
+
+ // just singular values first
+ moFull.svd(input, false);
+ moSym.svd(input, false);
+ moSISVD.svd(input, false);
+ final double[] fullSv = moFull.getSingularValues();
+ compareSingularValues(fullSv, moSym.getSingularValues(), fullSv.length);
+ compareSingularValues(fullSv, moSISVD.getSingularValues(), k); // SISVD only produces k values
+
+ // now with vectors
+ moFull.svd(input, true);
+ moSym.svd(input, true);
+ moSISVD.svd(input, true);
+ // TODO: better comparison is vector-wise, testing that sign changes are consistent but that
+ // requires non-zero elements
+ final Matrix fullVt = moFull.getVt();
+ compareMatrixElementMagnitudes(fullVt, moSym.getVt(), (int) fullVt.getNumRows());
+ compareMatrixElementMagnitudes(fullVt, moSISVD.getVt(), k); // SISVD only produces k vectors
+
+ // just to be sure
+ compareMatrixElementMagnitudes(fullVt, moFull.getVt(input), (int) fullVt.getNumRows());
+ }
+
+ @Test
+ public void checkInvalidMatrixSize() {
+ final int d = 10;
+ final int k = 6;
+ final Matrix A = generateIncreasingEye(d, 2 * k);
+ final MatrixOps mo = MatrixOps.newInstance(A, SVDAlgo.FULL, k);
+
+ Matrix B = generateIncreasingEye(d, 2 * k + 1);
+ try {
+ mo.svd(B, true);
+ fail();
+ } catch (final IllegalArgumentException e) {
+ // expected
+ }
+
+ B = generateIncreasingEye(d - 1, 2 * k);
+ try {
+ mo.svd(B, false);
+ fail();
+ } catch (final IllegalArgumentException e) {
+ // expected
+ }
+
+ }
+
+ private void compareSingularValues(final double[] A, final double[] B, final int n) {
+ assertEquals(A.length, B.length);
+
+ for (int i = 0; i < n; ++i) {
+ assertEquals(A[i], B[i], 1e-6);
+ }
+ }
+
+
+ private void compareMatrixElementMagnitudes(final Matrix A, final Matrix B, final int n) {
+ assertEquals(A.getNumColumns(), B.getNumColumns());
+ assertEquals(A.getNumRows(), B.getNumRows());
+
+ for (int i = 0; i < n; ++i) {
+ for (int j = 0; j < A.getNumColumns(); ++j) {
+ assertEquals(Math.abs(A.getElement(i, j)), Math.abs(B.getElement(i, j)), 1e-6);
+ }
+ }
+ }
+
+ /**
+ * Creates a scaled I matrix, where the diagonal consists of increasing integers,
+ * starting with 1.0.
+ * @param nRows number of rows
+ * @param nCols number of columns
+ * @return an ojAlgo-backed Matrix of the requested size with the increasing diagonal set
+ */
+ private static Matrix generateIncreasingEye(final int nRows, final int nCols) {
+ final Matrix m = new MatrixBuilder().setType(MatrixType.OJALGO).build(nRows, nCols);
+ for (int i = 0; (i < nRows) && (i < nCols); ++i) {
+ m.setElement(i, i, 1.0 + i);
+ }
+ return m;
+ }
+
+}
diff --git a/src/test/java/com/yahoo/sketches/vector/matrix/MatrixBuilderTest.java b/src/test/java/com/yahoo/sketches/vector/matrix/MatrixBuilderTest.java
index fd544b4..eb071cd 100644
--- a/src/test/java/com/yahoo/sketches/vector/matrix/MatrixBuilderTest.java
+++ b/src/test/java/com/yahoo/sketches/vector/matrix/MatrixBuilderTest.java
@@ -15,29 +15,29 @@
@Test
public void checkBuild() {
final MatrixBuilder builder = new MatrixBuilder();
- assertEquals(builder.getBackingType(), MatrixBuilder.Algo.OJALGO); // default type
+ assertEquals(builder.getBackingType(), MatrixType.OJALGO); // default type
- final Matrix m = builder.build(128, 512);
+ Matrix m = builder.build(128, 512);
+ assertNotNull(m);
+
+ builder.setType(MatrixType.MTJ);
+ assertEquals(builder.getBackingType(), MatrixType.MTJ);
+
+ m = builder.build(128, 512);
assertNotNull(m);
}
@Test
public void checkSetType() {
final MatrixBuilder builder = new MatrixBuilder();
- final MatrixBuilder.Algo type = builder.getBackingType();
- assertEquals(type, MatrixBuilder.Algo.OJALGO); // default type
- assertEquals(type.getId(), MatrixBuilder.Algo.OJALGO.getId());
- assertEquals(type.getName(), MatrixBuilder.Algo.OJALGO.getName());
+ final MatrixType type = builder.getBackingType();
+ assertEquals(type, MatrixType.OJALGO); // default type
+ assertEquals(type.getId(), MatrixType.OJALGO.getId());
+ assertEquals(type.getName(), MatrixType.OJALGO.getName());
- builder.setType(MatrixBuilder.Algo.NATIVE);
- assertEquals(builder.getBackingType(), MatrixBuilder.Algo.NATIVE);
- assertEquals(builder.getBackingType().toString(), "native");
-
- try {
- builder.build(10, 20);
- } catch (final IllegalArgumentException e) {
- // expected until native is implemented
- }
+ builder.setType(MatrixType.MTJ);
+ assertEquals(builder.getBackingType(), MatrixType.MTJ);
+ assertEquals(builder.getBackingType().toString(), "MTJ");
}
}
diff --git a/src/test/java/com/yahoo/sketches/vector/matrix/MatrixImplMTJTest.java b/src/test/java/com/yahoo/sketches/vector/matrix/MatrixImplMTJTest.java
new file mode 100644
index 0000000..b23b961
--- /dev/null
+++ b/src/test/java/com/yahoo/sketches/vector/matrix/MatrixImplMTJTest.java
@@ -0,0 +1,211 @@
+/*
+ * Copyright 2017, Yahoo, Inc.
+ * Licensed under the terms of the Apache License 2.0. See LICENSE file at the project root for terms.
+ */
+
+package com.yahoo.sketches.vector.matrix;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.fail;
+
+import org.testng.annotations.Test;
+
+import com.yahoo.memory.Memory;
+import com.yahoo.memory.WritableMemory;
+import no.uib.cipr.matrix.DenseMatrix;
+
+public class MatrixImplMTJTest {
+ @Test
+ public void checkInstantiation() {
+ final int nRows = 10;
+ final int nCols = 15;
+ final Matrix m = MatrixImplMTJ.newInstance(nRows, nCols);
+ assertEquals(m.getNumRows(), nRows);
+ assertEquals(m.getNumColumns(), nCols);
+
+ final DenseMatrix pds = (DenseMatrix) m.getRawObject();
+ assertEquals(pds.numRows(), nRows);
+ assertEquals(pds.numColumns(), nCols);
+
+ final Matrix wrapped = Matrix.wrap(pds);
+ MatrixTest.checkMatrixEquality(wrapped, m);
+ assertEquals(wrapped.getRawObject(), pds);
+ }
+
+ @Test
+ public void updateAndQueryValues() {
+ final int nRows = 5;
+ final int nCols = 5;
+ final Matrix m = generateIncreasingEye(nRows, nCols); // tests setElement() in method
+
+ for (int i = 0; i < nRows; ++i) {
+ for (int j = 0; j < nCols; ++j) {
+ final double val = m.getElement(i, j);
+ if (i == j) {
+ assertEquals(val, i + 1.0);
+ } else {
+ assertEquals(val, 0.0);
+ }
+ }
+ }
+ }
+
+ @Test
+ public void checkStandardSerialization() {
+ final int nRows = 3;
+ final int nCols = 7;
+ final Matrix m = generateIncreasingEye(nRows, nCols);
+
+ final byte[] mtxBytes = m.toByteArray();
+ assertEquals(mtxBytes.length, m.getSizeBytes());
+
+ final Memory mem = Memory.wrap(mtxBytes);
+ final Matrix tgt = MatrixImplMTJ.heapifyInstance(mem);
+ MatrixTest.checkMatrixEquality(tgt, m);
+ }
+
+ @Test
+ public void checkCompactSerialization() {
+ final int nRows = 4;
+ final int nCols = 7;
+ final Matrix m = generateIncreasingEye(nRows, nCols);
+
+ byte[] mtxBytes = m.toCompactByteArray(nRows - 1, 7);
+ assertEquals(mtxBytes.length, m.getCompactSizeBytes(nRows - 1, 7));
+
+ Memory mem = Memory.wrap(mtxBytes);
+ Matrix tgt = MatrixImplMTJ.heapifyInstance(mem);
+ for (int c = 0; c < nCols; ++c) {
+ for (int r = 0; r < (nRows - 1); ++r) {
+ assertEquals(tgt.getElement(r, c), m.getElement(r, c)); // equal here
+ }
+ // assuming nRows - 1 so check only the last row as being 0
+ assertEquals(tgt.getElement(nRows - 1, c), 0.0);
+ }
+
+ // test without compacting
+ mtxBytes = m.toCompactByteArray(nRows, nCols);
+ assertEquals(mtxBytes.length, m.getSizeBytes());
+ mem = Memory.wrap(mtxBytes);
+ tgt = MatrixImplMTJ.heapifyInstance(mem);
+ MatrixTest.checkMatrixEquality(tgt, m);
+ }
+
+ @Test
+ public void matrixRowOperations() {
+ final int nRows = 7;
+ final int nCols = 5;
+ final Matrix m = generateIncreasingEye(nRows, nCols);
+
+ final int tgtCol = 2;
+ final double[] v = m.getRow(tgtCol); // diagonal matrix, so this works ok
+ for (int i = 0; i < v.length; ++i) {
+ assertEquals(v[i], (i == tgtCol ? i + 1.0 : 0.0));
+ }
+
+ assertEquals(m.getElement(6, tgtCol), 0.0);
+ m.setRow(6, v);
+ assertEquals(m.getElement(6, tgtCol), tgtCol + 1.0);
+ }
+
+ @Test
+ public void matrixColumnOperations() {
+ final int nRows = 9;
+ final int nCols = 4;
+ final Matrix m = generateIncreasingEye(nRows, nCols);
+
+ final int tgtRow = 3;
+ final double[] v = m.getColumn(tgtRow); // diagonal matrix, so this works ok
+ for (int i = 0; i < v.length; ++i) {
+ assertEquals(v[i], (i == tgtRow ? i + 1.0 : 0.0));
+ }
+
+ assertEquals(m.getElement(tgtRow, 0), 0.0);
+ m.setColumn(0, v);
+ assertEquals(m.getElement(tgtRow, 0), tgtRow + 1.0);
+ }
+
+ @Test
+ public void invalidRowColumnOperations() {
+ final int nRows = 9;
+ final int nCols = 4;
+ final Matrix m = generateIncreasingEye(nRows, nCols);
+
+ final double[] shortRow = new double[nCols - 2];
+ try {
+ m.setRow(1, shortRow);
+ fail();
+ } catch (final IllegalArgumentException e) {
+ // expected
+ }
+
+ final double[] longColumn = new double[nRows + 2];
+ try {
+ m.setColumn(1, longColumn);
+ fail();
+ } catch (final IllegalArgumentException e) {
+ // expected
+ }
+ }
+
+ @Test
+ public void invalidSerVer() {
+ final int nRows = 3;
+ final int nCols = 3;
+ final Matrix m = generateIncreasingEye(nRows, nCols);
+ final byte[] sketchBytes = m.toByteArray();
+ final WritableMemory mem = WritableMemory.wrap(sketchBytes);
+ MatrixPreambleUtil.insertSerVer(mem.getArray(), mem.getCumulativeOffset(0L), 0);
+
+ try {
+ MatrixImplMTJ.heapifyInstance(mem);
+ fail();
+ } catch (final IllegalArgumentException e) {
+ // expected
+ }
+ }
+
+ @Test
+ public void invalidFamily() {
+ final int nRows = 3;
+ final int nCols = 3;
+ final Matrix m = generateIncreasingEye(nRows, nCols);
+ final byte[] sketchBytes = m.toByteArray();
+ final WritableMemory mem = WritableMemory.wrap(sketchBytes);
+ MatrixPreambleUtil.insertFamilyID(mem.getArray(), mem.getCumulativeOffset(0L), 0);
+
+ try {
+ MatrixImplMTJ.heapifyInstance(mem);
+ fail();
+ } catch (final IllegalArgumentException e) {
+ // expected
+ }
+ }
+
+ @Test
+ public void insufficientMemoryCapacity() {
+ final byte[] bytes = new byte[6];
+ final Memory mem = Memory.wrap(bytes);
+ try {
+ MatrixImplMTJ.heapifyInstance(mem);
+ fail();
+ } catch (final IllegalArgumentException e) {
+ // expected
+ }
+ }
+
+ /**
+ * Creates a scaled I matrix, where the diagonal consists of increasing integers,
+ * starting with 1.0.
+ * @param nRows number of rows
+ * @param nCols number of columns
+ * @return an MTJ-backed Matrix of the requested size with the increasing diagonal set
+ */
+ private static Matrix generateIncreasingEye(final int nRows, final int nCols) {
+ final Matrix m = MatrixImplMTJ.newInstance(nRows, nCols);
+ for (int i = 0; (i < nRows) && (i < nCols); ++i) {
+ m.setElement(i, i, 1.0 + i);
+ }
+ return m;
+ }
+}
diff --git a/src/test/java/com/yahoo/sketches/vector/matrix/MatrixTest.java b/src/test/java/com/yahoo/sketches/vector/matrix/MatrixTest.java
index fe4ad78..a2262d7 100644
--- a/src/test/java/com/yahoo/sketches/vector/matrix/MatrixTest.java
+++ b/src/test/java/com/yahoo/sketches/vector/matrix/MatrixTest.java
@@ -21,24 +21,25 @@
@Test
public void checkHeapify() {
- final Matrix m = Matrix.builder().setType(MatrixBuilder.Algo.OJALGO).build(3, 3);
+ final Matrix m = Matrix.builder().setType(MatrixType.OJALGO).build(3, 3);
final byte[] bytes = m.toByteArray();
final Memory mem = Memory.wrap(bytes);
println(MatrixPreambleUtil.preambleToString(mem));
- Matrix tgt = Matrix.heapify(mem, MatrixBuilder.Algo.OJALGO);
+ Matrix tgt = Matrix.heapify(mem, MatrixType.OJALGO);
assertTrue(tgt instanceof MatrixImplOjAlgo);
checkMatrixEquality(m, tgt);
- tgt = Matrix.heapify(mem, MatrixBuilder.Algo.NATIVE);
- assertNull(tgt);
+ tgt = Matrix.heapify(mem, MatrixType.MTJ);
+ assertTrue(tgt instanceof MatrixImplMTJ);
+ checkMatrixEquality(m, tgt);
}
@Test
public void checkWrap() {
assertNull(Matrix.wrap(null));
- final Matrix src = Matrix.builder().setType(MatrixBuilder.Algo.OJALGO).build(3, 3);
+ final Matrix src = Matrix.builder().setType(MatrixType.OJALGO).build(3, 3);
final Object obj = src.getRawObject();
final Matrix tgt = Matrix.wrap(obj);
assertTrue(tgt instanceof MatrixImplOjAlgo);