WIP: FD-based ridge regression
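
Adds a first pass at Frequent Directions (FD) based ridge regression:

- RidgeRegression: batch fit/predict with both an exact (full SVD)
  solver and an FD-sketched approximation
- RobustRidgeRegression: streaming variant that maintains an FD sketch
  of the data plus an accumulated A^T y vector, folding half the shed
  singular value mass into gamma at each rank reduction
- VectorNormalizer: now also tracks the target mean (the intercept)
  and includes it in serialization
- SketchesArgumentException for argument validation
- pom.xml: adds checker-qual and error-prone javac dependencies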
diff --git a/pom.xml b/pom.xml
index cd7c287..9cdf625 100644
--- a/pom.xml
+++ b/pom.xml
@@ -166,6 +166,11 @@
<artifactId>datasketches-memory</artifactId>
<version>${datasketches-memory.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.checkerframework</groupId>
+ <artifactId>checker-qual</artifactId>
+ <version>3.10.0</version>
+ </dependency>
<!--
<dependency>
<groupId>org.apache.commons</groupId>
@@ -173,6 +178,11 @@
<version>${commons-math3.version}</version>
</dependency>
-->
+ <dependency>
+ <groupId>com.google.errorprone</groupId>
+ <artifactId>javac</artifactId>
+ <version>9+181-r4173-1</version>
+ </dependency>
<!-- END: UNIQUE FOR THIS JAVA COMPONENT -->
<!-- Test Scope -->
diff --git a/src/main/java/org/apache/datasketches/vector/SketchesArgumentException.java b/src/main/java/org/apache/datasketches/vector/SketchesArgumentException.java
new file mode 100644
index 0000000..c1d33d9
--- /dev/null
+++ b/src/main/java/org/apache/datasketches/vector/SketchesArgumentException.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.vector;
+
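+/**
+ * Illegal or invalid argument passed to a sketch constructor or method. Mirrors the exception
+ * of the same name used by other DataSketches components.
+ */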
+public class SketchesArgumentException extends RuntimeException {
+  private static final long serialVersionUID = 1L;
+
+ public SketchesArgumentException(final String message) {
+ super(message);
+ }
+
+ public SketchesArgumentException(final String message, final Throwable throwable) {
+ super(message, throwable);
+ }
+}
diff --git a/src/main/java/org/apache/datasketches/vector/regression/RidgeRegression.java b/src/main/java/org/apache/datasketches/vector/regression/RidgeRegression.java
new file mode 100644
index 0000000..6057be1
--- /dev/null
+++ b/src/main/java/org/apache/datasketches/vector/regression/RidgeRegression.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.vector.regression;
+
+import org.apache.datasketches.vector.SketchesArgumentException;
+import org.apache.datasketches.vector.decomposition.FrequentDirections;
+import org.apache.datasketches.vector.matrix.Matrix;
+import org.ojalgo.function.PrimitiveFunction;
+import org.ojalgo.function.aggregator.Aggregator;
+import org.ojalgo.function.constant.PrimitiveMath;
+import org.ojalgo.matrix.decomposition.SingularValue;
+import org.ojalgo.matrix.store.MatrixStore;
+import org.ojalgo.matrix.store.Primitive64Store;
+import org.ojalgo.matrix.store.SparseStore;
+
+public class RidgeRegression {
+
+ private final int k_;
+ private final double gamma_;
+ private final boolean useRobust_;
+
+ private Primitive64Store xOffset_;
+ private Primitive64Store xScale_;
+
+ private Primitive64Store weights_;
+ private double intercept_;
+
+ private long n_;
+ private int d_;
+
+
+ public RidgeRegression(final int k, final double gamma, final boolean useRobust) {
+ k_ = k;
+ gamma_ = gamma;
+ useRobust_ = useRobust;
+
+ n_ = 0;
+ d_ = 0;
+ }
+
+  public double fit(Matrix data, double[] targets) {
+    return fit(data, targets, false);
+  }
+
+  /**
+   * Fits the regression model, optionally computing an exact (full SVD) solution rather than
+   * the Frequent Directions approximation.
+   * @param data an n x d data Matrix, with one input vector per row. MODIFIES INPUT DATA
+   * @param targets an n-dimensional array of regression target values. MODIFIED IN PLACE (centered)
+   * @param exact if true, computes exact solution, otherwise an approximation
+   * @return root-mean-square error of the fitted model on the training data
+   */
+ public double fit(Matrix data, double[] targets, boolean exact) {
+ n_ = data.getNumRows();
+ d_ = (int) data.getNumColumns();
+
+ // preallocate the structures we'll use
+ xOffset_ = Primitive64Store.FACTORY.make(1, d_);
+ xScale_ = Primitive64Store.FACTORY.make(1, d_);
+    weights_ = Primitive64Store.FACTORY.make(d_, 1); // column vector, so A * w conforms
+
+ preprocessData(data, targets, true, true);
+ solve(data, targets, exact);
+
+    final double[] predictions = predict(data, true);
+    // preprocessData() centered the targets in place, so compare in the centered space
+    for (int i = 0; i < predictions.length; ++i) {
+      predictions[i] -= intercept_;
+    }
+    return getError(predictions, targets);
+ }
+
+ public double[] predict(final Matrix data) {
+ return predict(data, false);
+ }
+
+ private double[] predict(final Matrix data, final boolean preNormalized) {
+    if (data.getNumColumns() != d_)
+      throw new SketchesArgumentException("Input matrix for prediction must have " + d_
+          + " columns, found " + data.getNumColumns());
+
+ final Primitive64Store mtx = (Primitive64Store) data.getRawObject();
+ final MatrixStore<Double> rawPredictions;
+
+ if (preNormalized) {
+ rawPredictions = mtx.multiply(weights_);
+ } else {
+ Primitive64Store adjustedMtx = mtx.copy();
+ adjustedMtx.modifyMatchingInRows(PrimitiveMath.SUBTRACT, xOffset_);
+ adjustedMtx.modifyMatchingInRows(PrimitiveMath.DIVIDE, xScale_);
+ rawPredictions = adjustedMtx.multiply(weights_);
+ }
+
+    // onAll() is lazy and returns a new store rather than modifying in place,
+    // so add the intercept explicitly
+    final double[] predictions = rawPredictions.toRawCopy1D();
+    for (int i = 0; i < predictions.length; ++i) {
+      predictions[i] += intercept_;
+    }
+    return predictions;
+ }
+
+ public double getError(final double[] y_pred, final double[] y_true) {
+ if (y_pred.length != y_true.length)
+ throw new RuntimeException("Predictions and true value vectors differ in length: "
+ + y_pred.length + " != " + y_true.length);
+
+ double cumSqErr = 0.0;
+ for (int i = 0; i < y_pred.length; ++i) {
+ double val = y_pred[i] - y_true[i];
+ cumSqErr += val * val;
+ }
+
+ return Math.sqrt(cumSqErr) / Math.sqrt(y_pred.length);
+ }
+
+ public double[] getWeights() {
+ return weights_.data.clone();
+ }
+
+ public double getIntercept() {
+ return intercept_;
+ }
+
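+  // Centers each column of the data to zero mean and, if requested, scales it to unit
+  // (population) standard deviation; targets are centered by their mean, which becomes the
+  // model intercept. Centering first means no penalty is applied to the intercept term.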
+ private void preprocessData(Matrix data, double[] targets, boolean fitIntercept, boolean normalize) {
+ Primitive64Store mtx = (Primitive64Store) data.getRawObject();
+
+ if (fitIntercept) {
+ mtx.reduceColumns(Aggregator.AVERAGE).supplyTo(xOffset_);
+ intercept_ = Primitive64Store.wrap(targets).aggregateAll(Aggregator.AVERAGE);
+
+ // subtract xOffset from input matrix, yOffset from targets
+ mtx.modifyMatchingInRows(PrimitiveMath.SUBTRACT, xOffset_);
+ for (int r = 0; r < n_; ++r) {
+ targets[r] -= intercept_;
+ }
+
+ if (normalize) {
+ mtx.reduceColumns(Aggregator.NORM2).supplyTo(xScale_);
+ // map any zeros to 1.0 and adjust from norm2 to stdev
+ PrimitiveFunction.Unary fixZero = arg -> arg == 0.0 ? 1.0 : Math.sqrt(arg * arg / n_);
+ xScale_.modifyAll(fixZero);
+ mtx.modifyMatchingInRows(PrimitiveMath.DIVIDE, xScale_);
+ } else {
+ xScale_.fillAll(1.0);
+ }
+ } else {
+ xOffset_.fillAll(0.0);
+ xScale_.fillAll(1.0);
+ intercept_ = 0.0;
+ }
+ }
+
+
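+  // Solves for the weights via the SVD. With A = U S V^T, the ridge solution
+  //   w = (A^T A + gamma I)^{-1} A^T y
+  // becomes w = V (S^2 + gamma I)^{-1} V^T A^T y when V is square. When V^T instead comes
+  // from a Frequent Directions sketch, A^T A is approximated by V S^2 V^T and the inverse
+  // splits across the sketched subspace and its complement:
+  //   (V S^2 V^T + gamma I)^{-1} = V (S^2 + gamma I)^{-1} V^T + (1/gamma)(I - V V^T),
+  // which yields the terms computed below.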
+ private void solve(Matrix data, double[] targets, boolean exact) {
+ final Primitive64Store sketchMtx;
+ final MatrixStore<Double> Vt;
+ final double[] sv;
+ final int nDim;
+ if (exact) {
+ nDim = d_;
+ sketchMtx = (Primitive64Store) data.getRawObject();
+ final SingularValue<Double> svd = SingularValue.PRIMITIVE.make(sketchMtx);
+ svd.decompose(sketchMtx);
+ sv = new double[nDim];
+ svd.getSingularValues(sv);
+ Vt = svd.getV().transpose();
+ } else {
+ final FrequentDirections fd = FrequentDirections.newInstance(k_, d_);
+ for (int r = 0; r < data.getNumRows(); ++r) {
+ fd.update(data.getRow(r));
+ }
+ fd.forceReduceRank();
+
+ sv = fd.getSingularValues(useRobust_);
+ Vt = (Primitive64Store) fd.getProjectionMatrix().getRawObject();
+ nDim = (int) Vt.countRows();
+ }
+
+ final MatrixStore<Double> ATy = ((Primitive64Store) data.getRawObject()).transpose().multiply(Primitive64Store.wrap(targets));
+ // TODO: seems there should be a modifyDiagonal() to apply this?
+ final SparseStore<Double> invDiag = SparseStore.makePrimitive(nDim, nDim);
+ for (int i = 0; i < sv.length; ++i) {
+ invDiag.set(i, i, 1.0/(sv[i] * sv[i] + gamma_));
+ }
+
+ MatrixStore<Double> firstTerm = (Vt.transpose().multiply(invDiag)).multiply(Vt.multiply(ATy));
+ MatrixStore<Double> secondTerm = ATy.multiply(1.0 / gamma_);
+ MatrixStore<Double> thirdTerm = Vt.transpose().multiply(1.0 / gamma_).multiply( Vt.multiply(ATy) );
+
+ firstTerm.add(secondTerm).subtract(thirdTerm).supplyTo(weights_);
+ }
+}
diff --git a/src/main/java/org/apache/datasketches/vector/regression/RobustRidgeRegression.java b/src/main/java/org/apache/datasketches/vector/regression/RobustRidgeRegression.java
new file mode 100644
index 0000000..5a6d96e
--- /dev/null
+++ b/src/main/java/org/apache/datasketches/vector/regression/RobustRidgeRegression.java
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.vector.regression;
+
+import org.apache.datasketches.vector.SketchesArgumentException;
+import org.ojalgo.array.Array1D;
+import org.ojalgo.function.constant.PrimitiveMath;
+import org.ojalgo.matrix.decomposition.SingularValue;
+import org.ojalgo.matrix.store.MatrixStore;
+import org.ojalgo.matrix.store.Primitive64Store;
+
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.ojalgo.matrix.store.SparseStore;
+
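+/**
+ * Streaming ridge regression based on a Robust Frequent Directions sketch. Input rows update a
+ * 2k x d sketch <tt>B</tt> of the (optionally normalized) data along with an accumulated
+ * A^T y vector; each rank reduction folds half the shed singular value mass into the
+ * regularization coefficient <tt>gamma</tt>, and solve() combines the sketch with the
+ * accumulated A^T y to produce the weight vector.
+ */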
+public class RobustRidgeRegression {
+ private final int d_;
+ private final int k_;
+ private final int l_; // convenience value so we don't need to compute 2*k frequently
+
+ private Primitive64Store xOffset_;
+ private Primitive64Store xScale_;
+
+ private final Primitive64Store B_; // Sketch of the data
+ private final Primitive64Store ATyAccum_; // Accumulates A^T y - a d-dim vector
+ private double intercept_;
+
+ private double gamma_;
+ private long n_;
+ private int nextZeroRow_;
+
+ private Primitive64Store weights_;
+
+ // transient values for SVD
+ private double[] sv_;
+ private Primitive64Store Vt_;
+ private SparseStore<Double> S_; // to hold singular value matrix
+ private SingularValue<Double> svd_;
+ //private Eigenvalue<Double> evd_;
+
+
+ /**
+ * Returns an object ready to accept data for Robust Frequent Directions Ridge Regression. The sketch size
+ * parameter <tt>k</tt> controls the accuracy/size trade-off, but must be no greater than <tt>2 d</tt>.
+ * @param gamma A nonnegative regularization coefficient
+ * @param d The number of dimensions in each input vector
+ * @param k The sketch size parameter
+ */
+ RobustRidgeRegression(final double gamma, final int d, final int k) {
+ if (gamma < 0.0)
+ throw new SketchesArgumentException("Gamma must be nonnegative. Found: " + gamma);
+ if (k < 1)
+ throw new SketchesArgumentException("k must be at least 1. Found: " + k);
+ if (d < 2 * k)
+ throw new SketchesArgumentException("d must be at least 2k. Found d=" + d + ", k=" + k);
+
+ gamma_ = gamma;
+ d_ = d;
+ k_ = k;
+ l_ = 2 * k;
+
+ n_ = 0;
+ nextZeroRow_ = 0;
+
+ B_ = Primitive64Store.FACTORY.make(l_, d_);
+ ATyAccum_ = Primitive64Store.FACTORY.make(d_, 1);
+ }
+
+ /**
+ * Returns an object ready to accept data for Robust Frequent Directions Ridge Regression. Uses a default
+ * sketch size of the maximum supported for a given value of <tt>d</tt>.
+ * @param gamma A nonnegative regularization coefficient
+ * @param d The number of dimensions in each input vector
+ */
+ RobustRidgeRegression(final double gamma, final int d) {
+ this(gamma, d, d / 2);
+ }
+
+ /**
+ * Initializes mean/variance normalization based on a VectorNormalizer object. The input normalizer
+ * must match the configured dimensionality of the regression object.
+ * @param normalizer A VectorNormalizer run on (a sample of) the data to be modeled.
+ */
+ public void setNormalization(@NonNull final VectorNormalizer normalizer) {
+ if (normalizer.getD() != d_)
+ throw new SketchesArgumentException("VectorNormalizer dimension must match configured dimensions. "
+ + normalizer.getD() + " != " + d_);
+
+ xOffset_ = Primitive64Store.wrap(normalizer.getMean());
+ xScale_ = Primitive64Store.wrap(normalizer.getSampleVariance()); // variance, not std. deviation
+ xScale_.modifyAll(PrimitiveMath.SQRT);
+ intercept_ = normalizer.getIntercept();
+ }
+
+ /**
+ * Initializes mean/variance normalization using specified arrays. Note that the second argument
+ * is an array of <em>variance</em> values, not standard deviations. Both arrays must match the
+ * configured dimensionality of the regression object.
+ * @param means An array of mean values
+   * @param variances An array of variance values
+   * @param intercept The mean of the target values
+ */
+ public void setNormalization(final double[] means, final double[] variances, final double intercept) {
+ if (means == null || variances == null)
+ throw new SketchesArgumentException("Mean and variance arrays cannot be null.");
+ if (means.length != d_ || variances.length != d_)
+ throw new SketchesArgumentException("Mean and variance arrays must be of length " + d_);
+
+ xOffset_ = Primitive64Store.wrap(means.clone());
+ xScale_ = Primitive64Store.wrap(variances.clone()); // variance, not std. deviation
+ xScale_.modifyAll(PrimitiveMath.SQRT);
+ intercept_ = intercept;
+ }
+
+ // add a single vector to the sketch
+ public void update(final double[] data, final double target) {
+ if (data == null || data.length != d_)
+ throw new SketchesArgumentException("data must be a non-null vector of length " + d_);
+
+    update(Primitive64Store.wrap(data, 1), Primitive64Store.wrap(new double[] {target}));
+ }
+
+ // add multiple vectors to the sketch
+ public void update(@NonNull final Primitive64Store data, @NonNull final Primitive64Store targets) {
+ if (data.countColumns() != d_)
+ throw new SketchesArgumentException("data must have " + d_ + " columns. Found: " + data.countColumns());
+ if (data.countRows() != targets.count())
+ throw new SketchesArgumentException("number of rows in data (" + data.countRows() + ") does not match"
+ + " number of targete (" + targets.count() + ")");
+
+ // append rows to B_ until we have l_ of them, then reduce and adjust gamma
+ for (int i = 0; i < data.countRows(); ++i) {
+ Array1D<Double> row = data.sliceRow(i);
+
+ if (nextZeroRow_ == l_) {
+ reduceRank();
+ }
+
+ // accumulate values and copy row into B_, applying normalization if supplied
+ // TODO: may be able to improve performance by normalizing data first?
+      if (xOffset_ != null) {
+        for (int j = 0; j < d_; ++j) {
+          // accumulate A^T y in the same normalized space as the sketched rows
+          final double normVal = (row.get(j) - xOffset_.get(j)) / xScale_.get(j);
+          ATyAccum_.set(j, ATyAccum_.get(j) + (normVal * (targets.get(i) - intercept_)));
+          B_.set(nextZeroRow_, j, normVal);
+        }
+      } else {
+        for (int j = 0; j < d_; ++j) {
+          ATyAccum_.set(j, ATyAccum_.get(j) + (row.get(j) * targets.get(i)));
+          B_.set(nextZeroRow_, j, row.get(j));
+        }
+      }
+
+ ++n_;
+ ++nextZeroRow_;
+ }
+ }
+
+  public void merge(@NonNull final RobustRidgeRegression other) {
+    // TODO: must match d, k (should be able to merge larger k into smaller?)
+    throw new UnsupportedOperationException("merge() is not yet implemented");
+  }
+
+  public double[] solve() {
+    if (n_ == 0) {
+      throw new SketchesArgumentException("solve() called before any data has been added");
+    }
+    if (weights_ == null) {
+      weights_ = Primitive64Store.FACTORY.make(d_, 1); // column vector, so the products below conform
+    }
+
+ // make sure any new data has contributed to V^T and singular values
+ reduceRank();
+
+    // TODO: seems there should be a modifyDiagonal() to apply this?
+    // dimensions must match Vt_ (l_ x d_) for the products below to conform
+    final SparseStore<Double> invDiag = SparseStore.makePrimitive(sv_.length, sv_.length);
+    for (int i = 0; i < sv_.length; ++i) {
+      invDiag.set(i, i, 1.0 / ((sv_[i] * sv_[i]) + gamma_));
+    }
+
+ MatrixStore<Double> firstTerm = (Vt_.transpose().multiply(invDiag)).multiply(Vt_.multiply(ATyAccum_));
+ MatrixStore<Double> secondTerm = ATyAccum_.multiply(1.0 / gamma_);
+ MatrixStore<Double> thirdTerm = Vt_.transpose().multiply(1.0 / gamma_).multiply( Vt_.multiply(ATyAccum_) );
+
+ firstTerm.add(secondTerm).subtract(thirdTerm).supplyTo(weights_);
+
+ return weights_.toRawCopy1D();
+ }
+
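+  // Robust FD rank reduction: decompose B_, shrink the squared singular values by the (k+1)th
+  // squared singular value, and fold half of that shed mass into gamma_. Accumulating the
+  // shrinkage into the regularizer rather than discarding it is what makes the sketch "robust".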
+  private void reduceRank() {
+    if (n_ == 0) { return; }
+
+    if (svd_ == null) {
+      svd_ = SingularValue.PRIMITIVE.make(B_);
+      sv_ = new double[l_];
+      Vt_ = Primitive64Store.FACTORY.make(l_, d_); // holds the economy-size V^T of B_
+      S_ = SparseStore.makePrimitive(sv_.length, sv_.length);
+    }
+
+    // full SVD
+    // computes U and V matrices even though we only use the latter
+    svd_.decompose(B_);
+    svd_.getV().transpose().supplyTo(Vt_);
+    svd_.getSingularValues(sv_);
+
+    // zero-out singular values and update gamma_; re-running on an already-reduced B_
+    // is a no-op, since then sv_[k_] == 0
+    double medianSVSq = sv_[k_]; // the (k+1)th largest singular value, not yet squared
+    medianSVSq *= medianSVSq;
+    gamma_ += 0.5 * medianSVSq;
+    for (int i = 0; i < k_; ++i) {
+      final double val = sv_[i];
+      final double adjSqSV = (val * val) - medianSVSq;
+      S_.set(i, i, adjSqSV < 0 ? 0.0 : Math.sqrt(adjSqSV)); // just to be safe
+    }
+    for (int i = k_; i < S_.countColumns(); ++i) {
+      S_.set(i, i, 0.0);
+    }
+
+    // store the result back in B_
+    S_.multiply(Vt_, B_);
+
+    // update bookkeeping now
+    nextZeroRow_ = (int) Math.min(k_, n_);
+  }
+
+ public static void main(String[] args) {
+ RobustRidgeRegression rr = new RobustRidgeRegression(1.0, 10);
+ rr.update(new double[10], 1.0);
+ rr.solve();
+ }
+}
diff --git a/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java b/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java
index 49a5eec..0047e00 100644
--- a/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java
+++ b/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java
@@ -24,6 +24,7 @@
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.datasketches.vector.MatrixFamily;
+import org.checkerframework.checker.nullness.qual.NonNull;
/**
* Computes mean and variance for each of d dimensions of an input vector using Welford's online algorithm,
@@ -43,6 +44,12 @@
*
* || 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
* 1 ||-------------------------Num. Vectors Processed (n)--------------------------|
+ *
+ * || 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 |
+ * 2 ||---------------------------Intercept (target mean)---------------------------|
+ *
+ * || 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
+ * 3 ||-----------------------------start of mean array-----------------------------|
* </pre>
*
* @author Jon Malkin
@@ -52,10 +59,11 @@
private final int d_;
private final double[] mean_;
private final double[] M2_;
+ private double intercept_;
private long n_;
// Preamble byte Addresses
- static final int PREAMBLE_LONGS_BYTE = 0;
+ static final int PREAMBLE_LONGS_BYTE = 0;
static final int SER_VER_BYTE = 1;
static final int FAMILY_BYTE = 2;
static final int FLAGS_BYTE = 3;
@@ -78,9 +86,10 @@
throw new IllegalArgumentException("d cannot be < 1. Found: " + d);
d_ = d;
+ n_ = 0;
+ intercept_ = 0.0;
mean_ = new double[d_];
M2_ = new double[d_];
- n_ = 0;
}
/**
@@ -90,13 +99,15 @@
public VectorNormalizer(final VectorNormalizer other) {
d_ = other.d_;
n_ = other.n_;
+ intercept_ = other.intercept_;
mean_ = other.mean_.clone();
M2_ = other.M2_.clone();
}
- private VectorNormalizer(final int d, final long n, final double[] mean, final double[] M2) {
+ private VectorNormalizer(final int d, final long n, final double[] mean, final double[] M2, final double intercept) {
d_ = d;
n_ = n;
+ intercept_ = intercept;
mean_ = mean;
M2_ = M2;
}
@@ -149,13 +160,16 @@
long offsetBytes = (long) preLongs * Long.BYTES;
// check capacity for the rest
- final long bytesNeeded = offsetBytes + (2L * d * Double.BYTES);
+ final long bytesNeeded = offsetBytes + (((2L * d) + 1) * Double.BYTES);
if (srcMem.getCapacity() < bytesNeeded) {
throw new IllegalArgumentException(
"Possible Corruption: Size of Memory not large enough: Size: " + srcMem.getCapacity()
+ ", Required: " + bytesNeeded);
}
+ final double intercept = srcMem.getDouble(offsetBytes);
+ offsetBytes += Double.BYTES;
+
final double[] mean = new double[d];
srcMem.getDoubleArray(offsetBytes, mean, 0, d);
offsetBytes += (long) d * Double.BYTES;
@@ -163,7 +177,7 @@
final double[] M2 = new double[d];
srcMem.getDoubleArray(offsetBytes, M2, 0, d);
- return new VectorNormalizer(d, n, mean, M2);
+ return new VectorNormalizer(d, n, mean, M2, intercept);
}
/**
@@ -178,7 +192,7 @@
? MatrixFamily.VECTORNORMALIZER.getMinPreLongs()
: MatrixFamily.VECTORNORMALIZER.getMaxPreLongs();
- final int outBytes = (preLongs * Long.BYTES) + (empty ? 0 : 2 * d_ * Double.BYTES);
+ final int outBytes = (preLongs * Long.BYTES) + (empty ? 0 : (1 + 2 * d_) * Double.BYTES);
final byte[] outArr = new byte[outBytes];
final WritableMemory memOut = WritableMemory.wrap(outArr);
final Object memObj = memOut.getArray();
@@ -193,6 +207,8 @@
if (!empty) {
insertN(memObj, memAddr, n_);
long offset = (long) preLongs * Long.BYTES;
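+      // intercept is stored immediately after the preamble, per the layout documented above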
+ memOut.putDouble(offset, intercept_);
+ offset += Double.BYTES;
memOut.putDoubleArray(offset, mean_, 0, d_);
offset += (long) d_ * Double.BYTES;
memOut.putDoubleArray(offset, M2_, 0, d_);
@@ -242,6 +258,17 @@
}
/**
+ * Returns the mean of the target value, aka the intercept
+ * @return Mean of the target value
+ */
+ public double getIntercept() {
+ if (n_ == 0)
+ return Double.NaN;
+ else
+ return intercept_;
+ }
+
+ /**
* Returns the sample variance array represented in this object. Returns an array of NaN if N = 0 and an
* array of zeros if N = 1.
* @return The sample variance array represented in this object
@@ -287,7 +314,7 @@
}
}
- public void update(double[] x) {
+ public void update(final double[] x, final double target) {
if (x == null)
return;
@@ -302,17 +329,18 @@
double d2 = x[i] - mean_[i]; // x_i - newMean_i
M2_[i] += d1 * d2;
}
+
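+    // Welford-style running mean of the targets; this becomes the model intercept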
+ double delta = target - intercept_;
+ intercept_ += delta / n_;
}
- public void merge(VectorNormalizer other) {
- if (other == null)
- return;
-
+ public void merge(@NonNull final VectorNormalizer other) {
if (other.d_ != d_)
throw new IllegalArgumentException("Input VectorNormalizer must have d= " + d_ + ". Found: " + other.d_);
long combinedN = n_ + other.n_;
double varCountScalar = (n_ * other.n_) / (double) combinedN; // n_A * n_B / (n_A + n_B)
+ intercept_ = ((n_ * intercept_) + (other.n_ * other.intercept_)) / combinedN;
for (int i = 0; i < d_; ++i) {
double meanDiff = other.mean_[i] - mean_[i];
mean_[i] = ((n_ * mean_[i]) + (other.n_ * other.mean_[i])) / combinedN;
@@ -325,7 +353,7 @@
if (n_ == 0) {
return MatrixFamily.VECTORNORMALIZER.getMinPreLongs() * Long.BYTES;
} else {
- return (MatrixFamily.VECTORNORMALIZER.getMaxPreLongs()) * Long.BYTES + (2 * d_ * Double.BYTES);
+ return (MatrixFamily.VECTORNORMALIZER.getMaxPreLongs()) * Long.BYTES + ((1 + 2 * d_) * Double.BYTES);
}
}
diff --git a/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java b/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java
index 94da914..5d32560 100644
--- a/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java
+++ b/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java
@@ -1,2 +1,159 @@
-package org.apache.datasketches.vector.regression;public class RidgeRegressionTest {
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.vector.regression;
+
+import static org.testng.Assert.assertEquals;
+
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+
+import org.apache.datasketches.vector.matrix.Matrix;
+import org.testng.annotations.Test;
+
+public class RidgeRegressionTest {
+
+ @Test
+ public void normalize() {
+ final int nRows = 5;
+ final int nCols = 2;
+ final Matrix m = Matrix.builder().build(nRows, nCols);
+ m.setElement(0, 0, 1);
+ m.setElement(1, 0, 2);
+ m.setElement(2, 0, 3);
+ m.setElement(3, 0, 4);
+ m.setElement(4, 0, 5);
+ m.setElement(0, 1, 10);
+ m.setElement(1, 1, 20);
+ m.setElement(2, 1, 30);
+ m.setElement(3, 1, 40);
+ m.setElement(4, 1, 50);
+
+ final double[] targets = new double[] {-1, 1, 0, -1, 0.5};
+
+ //RidgeRegression rr = new RidgeRegression(5, 1.0, true);
+ //rr.fit(m, targets);
+ }
+
+ @Test
+ public void basicExactRegression() {
+ final int nRows = 5;
+ final int nCols = 2;
+ final Matrix m = Matrix.builder().build(nRows, nCols);
+ m.setElement(0, 0, 2);
+ m.setElement(1, 0, 3);
+ m.setElement(2, 0, 5);
+ m.setElement(3, 0, 7);
+ m.setElement(4, 0, 9);
+ m.setColumn(1, new double[]{0,0,0,0,0});
+ final double[] targets = new double[] {4, 5, 7, 10, 15};
+
+ RidgeRegression rr = new RidgeRegression(5, 0.0, false);
+ rr.fit(m, targets, true);
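+    // sanity reference (not asserted while still WIP): with gamma = 0 and exact = true this is
+    // OLS, so slope = Sxy/Sxx = 49.8/32.8 ~= 1.518 and intercept = mean(y) = 8.2; the weight
+    // reported in normalized space should be ~1.518 * popStdev(col 0) ~= 3.89, and weight 1 = 0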
+ System.out.println("Weights:");
+ for (int i = 0; i < nCols; ++i) {
+ System.out.println("\t" + i + ":\t" + rr.getWeights()[i]);
+ }
+ //System.out.println("Slope: " + rr.getWeights()[0]);
+ System.out.println("Intercept: " + rr.getIntercept());
+ }
+
+  @Test(enabled = false) // relies on local data files; run manually
+  public void yearDataTest() {
+ final int nTrain = 16000;
+ final int nValid = 4000;
+ final int nTest = 5000;
+ final String path = "/Users/jmalkin/projects/FrequentDirectionsRidgeRegression/notebooks/SongPredictions/";
+ //Matrix fullTrain = loadTSVData(path + "years_train.tsv", nTrain);
+ //Matrix fullTest = loadTSVData(path + "years_test.tsv", nTest);
+ Matrix fullTrain = loadTSVData(path + "years_train.out", nTrain);
+ Matrix fullValid = loadTSVData(path + "years_valid.out", nValid);
+ Matrix fullTest = loadTSVData(path + "years_test.out", nTest);
+
+ final int d = (int) fullTrain.getNumColumns() - 1;
+ assertEquals(d, fullTest.getNumColumns() - 1);
+ assertEquals(nTrain, fullTrain.getNumRows());
+ assertEquals(nValid, fullValid.getNumRows());
+ assertEquals(nTest, fullTest.getNumRows());
+
+ // last column is targets
+ double[] yTrain = fullTrain.getColumn(d);
+ double[] yValid = fullValid.getColumn(d);
+ double[] yTest = fullTest.getColumn(d);
+
+ // grab the rest as training sets
+ Matrix xTrain = Matrix.builder().build(nTrain, d);
+ Matrix xValid = Matrix.builder().build(nValid, d);
+ Matrix xTest = Matrix.builder().build(nTest, d);
+ for (int i = 0; i < d; ++i) {
+ xTrain.setColumn(i, fullTrain.getColumn(i));
+ xValid.setColumn(i, fullValid.getColumn(i));
+ xTest.setColumn(i, fullTest.getColumn(i));
+ }
+
+ RidgeRegression rr = new RidgeRegression(256, 10000.0, false);
+ double error = rr.fit(xTrain, yTrain, true);
+ System.out.print("[");
+ for (final double w : rr.getWeights()) {
+ System.out.print(w + "\t");
+ }
+ System.out.println("]");
+ System.out.println("Intercept: " + rr.getIntercept());
+
+ // (needlessly) computed as part of fit
+ System.out.println("Train error: " + error);
+
+ double[] pred = rr.predict(xValid);
+ error = rr.getError(pred, yValid);
+ System.out.println("Validation error: " + error);
+
+ pred = rr.predict(xTest);
+ error = rr.getError(pred, yTest);
+ System.out.println("Test error: " + error);
+ }
+
+ Matrix loadTSVData(final String inputFile, final int nRows) {
+ Matrix data = null;
+ int row = 0;
+ try (BufferedReader br = new BufferedReader(new FileReader(inputFile))) {
+ String line;
+ while ((line = br.readLine()) != null) {
+ String[] strValues = line.split("\t");
+ double[] values = new double[strValues.length];
+
+ for (int d = 0; d < strValues.length; ++d)
+ values[d] = Double.parseDouble(strValues[d]);
+
+ if (data == null) {
+ data = Matrix.builder().build(nRows, values.length);
+ }
+ data.setRow(row, values);
+ ++row;
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ assertEquals(row, nRows);
+ return data;
+ }
+
+
}
diff --git a/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java b/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java
index e6a7807..6c55aa1 100644
--- a/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java
+++ b/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java
@@ -34,7 +34,6 @@
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
-import javax.imageio.plugins.jpeg.JPEGImageWriteParam;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.datasketches.memory.Memory;
@@ -88,7 +87,8 @@
final VectorNormalizer vn = new VectorNormalizer(d);
final double[] input = {-1, 0, 0.5};
- vn.update(input);
+ final double target = 0.3;
+ vn.update(input, target);
assertEquals(vn.getN(), 1);
assertFalse(vn.isEmpty());
@@ -107,6 +107,8 @@
assertEquals(sampleVar[i], 0.0);
assertEquals(popVar[i], 0.0);
}
+    // intercept should equal the target
+ assertEquals(vn.getIntercept(), target);
}
@Test
@@ -123,7 +125,8 @@
input[0] = rand.nextGaussian(); // mean = 0.0, var = 1.0
input[1] = rand.nextDouble() * 2.0; // mean = 1.0, var = (2-0)^2/12 = 1/3
input[2] = rand.nextDouble() - 0.5; // mean = 0.0, var = (1-0)^2/12
- vn.update(input);
+ double target = rand.nextGaussian() - 1.0; // mean = -1.0
+ vn.update(input, target);
}
assertFalse(vn.isEmpty());
@@ -132,6 +135,7 @@
assertEquals(mean[0], 0.0, tol);
assertEquals(mean[1], 1.0, tol);
assertEquals(mean[2], 0.0, tol);
+ assertEquals(vn.getIntercept(), -1.0, tol);
// n is large enough that sample vs population variance won't matter for testing
final double[] sampleVar = vn.getSampleVariance();
@@ -165,15 +169,18 @@
// data expectations:
// dimension 0: zero-mean, unit-variance Gaussian, even after merging
// dimension 1: U[0,2] + U[2,4) -> U[0,4), so mean = 2.0 and var = 4^2/12=4/3
+ // target: N(-1,1) + N(1,1), so mean = 0.0 and variance unmeasured
final double[] input = new double[d];
for (int i = 0; i < n; ++i) {
input[0] = rand.nextGaussian();
input[1] = (rand.nextDouble() * 2.0) + 2.0;
- vn1.update(input);
+ double target = rand.nextGaussian() - 1.0;
+ vn1.update(input, target);
input[0] = rand.nextGaussian();
input[1] = rand.nextDouble() * 2.0;
- vn2.update(input);
+ target = rand.nextGaussian() + 1.0;
+ vn2.update(input, target);
}
vn1.merge(vn2);
@@ -182,6 +189,7 @@
final double[] mean = vn1.getMean();
assertEquals(mean[0], 0.0, tol);
assertEquals(mean[1], 2.0, tol);
+ assertEquals(vn1.getIntercept(), 0.0, tol);
// n is large enough that sample vs population variance won't matter for testing
final double[] sampleVar = vn1.getSampleVariance();
@@ -200,15 +208,12 @@
final double[] input = new double[d];
for (int i = 0; i < d; ++i) { input[i] = 1.0 * i; }
- vn.update(input);
- assertEquals(vn.getN(), 1);
-
- vn.update(null);
+ vn.update(input, 0.0);
assertEquals(vn.getN(), 1);
try {
final double[] badInput = {1.0};
- vn.update(badInput);
+ vn.update(badInput, 0.0);
fail();
} catch (IllegalArgumentException e) {
// expected
@@ -223,10 +228,7 @@
double[] input = new double[d];
for (int i = 0; i < d; ++i) { input[i] = 1.0 * i; }
- vn1.update(input);
- assertEquals(vn1.getN(), 1);
-
- vn1.merge(null);
+ vn1.update(input, 1.0);
assertEquals(vn1.getN(), 1);
// update with a non-empty VN with a different value of d
@@ -234,7 +236,7 @@
final VectorNormalizer vn2 = new VectorNormalizer(d2);
input = new double[d2];
for (int i = 0; i < d2; ++i) { input[i] = 1.0 * i; }
- vn2.update(input);
+ vn2.update(input, 2.0);
assertEquals(vn2.getN(), 1);
try {
@@ -258,7 +260,7 @@
for (int j = 0; j < d; ++j) {
input[j] = rand.nextDouble();
}
- vn.update(input);
+ vn.update(input, rand.nextGaussian());
}
final VectorNormalizer vnCopy = new VectorNormalizer(vn);
@@ -294,7 +296,7 @@
for (int j = 0; j < d; ++j) {
input[j] = rand.nextGaussian();
}
- vn.update(input);
+ vn.update(input, rand.nextDouble());
}
outBytes = vn.toByteArray();
@@ -304,6 +306,7 @@
assertFalse(rebuilt.isEmpty());
assertEquals(vn.getD(), rebuilt.getD());
assertEquals(vn.getN(), rebuilt.getN());
+ assertEquals(vn.getIntercept(), rebuilt.getIntercept());
final double[] originalMean = vn.getMean();
-    final double[] rebuiltMean = vn.getMean();
+    final double[] rebuiltMean = rebuilt.getMean();
@@ -413,7 +416,7 @@
for (int j = 0; j < d; ++j) {
input[j] = rand.nextDouble();
}
- vn.update(input);
+ vn.update(input, rand.nextDouble());
}
assertFalse(vn.isEmpty());