package org.apache.datasketches.vector.regression;
import org.apache.datasketches.vector.SketchesArgumentException;
import org.ojalgo.array.Array1D;
import org.ojalgo.function.constant.PrimitiveMath;
import org.ojalgo.matrix.decomposition.SingularValue;
import org.checkerframework.checker.nullness.qual.NonNull;
public class RobustRidgeRegression {
// Number of dimensions in each input vector.
private final int d_;
// Sketch size parameter; the constructor requires 2 * k_ <= d_.
private final int k_;
private final int l_; // cached 2 * k_ so it need not be recomputed; also the row count allocated for B_
// Per-dimension offsets (means) used for normalization; stays null until setNormalization() is called.
private Primitive64Store xOffset_;
// Per-dimension divisors used for normalization; setNormalization() stores variances here,
// not standard deviations.
private Primitive64Store xScale_;
private final Primitive64Store B_; // Sketch of the data, allocated as l_ x d_
private final Primitive64Store ATyAccum_; // Accumulates A^T y - a d-dim vector (allocated as d_ x 1)
// Target offset; update() subtracts it from each target when normalization is configured.
private double intercept_;
// Regularization coefficient; starts at the constructor argument and is increased by reduceRank().
private double gamma_;
// Initialized to 0 by the constructor and read by reduceRank(); presumably the number of vectors
// seen so far -- confirm against update().
private long n_;
// Compared against l_ in update() to detect a full sketch and reset by reduceRank(); presumably
// the index of the first unused row of B_.
private int nextZeroRow_;
// Allocated lazily by solve() as a 1 x d_ store.
private Primitive64Store weights_;
// transient values for SVD
private double[] sv_; // singular values; allocated with length l_ on the first reduceRank()
private Primitive64Store Vt_; // presumably V^T from the SVD of B_; its assignment is not visible here
private SparseStore<Double> S_; // to hold singular value matrix
private SingularValue<Double> svd_; // decomposition object, created on the first reduceRank()
//private Eigenvalue<Double> evd_;
/**
 * Returns an object ready to accept data for Robust Frequent Directions Ridge Regression. The sketch size
 * parameter <tt>k</tt> controls the accuracy/size trade-off, but must be no greater than <tt>d / 2</tt>.
 *
 * @param gamma A nonnegative regularization coefficient
 * @param d The number of dimensions in each input vector
 * @param k The sketch size parameter; must satisfy <tt>k &gt;= 1</tt> and <tt>2 k &lt;= d</tt>
 * @throws SketchesArgumentException if <tt>gamma</tt> is negative or NaN, if <tt>k &lt; 1</tt>,
 *         or if <tt>d &lt; 2 k</tt>
 */
RobustRidgeRegression(final double gamma, final int d, final int k) {
  // Written as a negated ">=" so that NaN, which fails every comparison, is rejected as well.
  if (!(gamma >= 0.0)) {
    throw new SketchesArgumentException("Gamma must be nonnegative. Found: " + gamma);
  }
  if (k < 1) {
    throw new SketchesArgumentException("k must be at least 1. Found: " + k);
  }
  // Same condition as d < 2 * k for every int d once k >= 1, but it cannot overflow for large k.
  if (k > d / 2) {
    throw new SketchesArgumentException("d must be at least 2k. Found d=" + d + ", k=" + k);
  }

  gamma_ = gamma;
  d_ = d;
  k_ = k;
  l_ = 2 * k; // cannot overflow: k <= d / 2 was verified above
  n_ = 0;
  nextZeroRow_ = 0;

  B_ = Primitive64Store.FACTORY.make(l_, d_);
  ATyAccum_ = Primitive64Store.FACTORY.make(d_, 1);
}
/**
 * Creates a regression object that uses the largest sketch size supported for the given
 * dimensionality, i.e. <tt>k = d / 2</tt> using integer division.
 *
 * @param gamma A nonnegative regularization coefficient
 * @param d The number of dimensions in each input vector
 */
RobustRidgeRegression(final double gamma, final int d) {
  this(gamma, d, d / 2);
}
* Initializes mean/variance normalization based on a VectorNormalizer object. The input normalizer
* must match the configured dimensionality of the regression object.
* @param normalizer A VectorNormalizer run on (a sample of) the data to be modeled.
public void setNormalization(@NonNull final VectorNormalizer normalizer) {
if (normalizer.getD() != d_)
throw new SketchesArgumentException("VectorNormalizer dimension must match configured dimensions. "
+ normalizer.getD() + " != " + d_);
xOffset_ = Primitive64Store.wrap(normalizer.getMean());
xScale_ = Primitive64Store.wrap(normalizer.getSampleVariance()); // variance, not std. deviation
intercept_ = normalizer.getIntercept();
* Initializes mean/variance normalization using specified arrays. Note that the second argument
* is an array of <em>variance</em> values, not standard deviations. Both arrays must match the
* configured dimensionality of the regression object.
* @param means An array of mean values
* @param variances An array of variance values
public void setNormalization(final double[] means, final double[] variances, final double intercept) {
if (means == null || variances == null)
throw new SketchesArgumentException("Mean and variance arrays cannot be null.");
if (means.length != d_ || variances.length != d_)
throw new SketchesArgumentException("Mean and variance arrays must be of length " + d_);
xOffset_ = Primitive64Store.wrap(means.clone());
xScale_ = Primitive64Store.wrap(variances.clone()); // variance, not std. deviation
intercept_ = intercept;
// add a single vector to the sketch
public void update(final double[] data, final double target) {
if (data == null || data.length != d_)
throw new SketchesArgumentException("data must be a non-null vector of length " + d_);
update(Primitive64Store.wrap(data, 1), Primitive64Store.wrap(target));
// add multiple vectors to the sketch
public void update(@NonNull final Primitive64Store data, @NonNull final Primitive64Store targets) {
if (data.countColumns() != d_)
throw new SketchesArgumentException("data must have " + d_ + " columns. Found: " + data.countColumns());
if (data.countRows() != targets.count())
throw new SketchesArgumentException("number of rows in data (" + data.countRows() + ") does not match"
+ " number of targete (" + targets.count() + ")");
// append rows to B_ until we have l_ of them, then reduce and adjust gamma
for (int i = 0; i < data.countRows(); ++i) {
Array1D<Double> row = data.sliceRow(i);
if (nextZeroRow_ == l_) {
// accumulate values and copy row into B_, applying normalization if supplied
// TODO: may be able to improve performance by normalizing data first?
if (xOffset_ != null) {
for (int j = 0; j < d_; ++j) {
final double accumUpdate = ATyAccum_.get(j) + row.get(j) * (targets.get(i) - intercept_);
ATyAccum_.set(j, accumUpdate);
B_.set(i, j, (row.get(j) - xOffset_.get(j)) / xScale_.get(j));
} else {
for (int j = 0; j < d_; ++j) {
final double accumUpdate = ATyAccum_.get(j) + (row.get(j) * targets.get(i));
ATyAccum_.set(j, accumUpdate);
B_.set(i, j, row.get(j));
// Intended to combine the state of another regression object into this one.
// NOTE(review): no merge logic is visible in this block, so as shown the call leaves this object
// unchanged -- confirm whether the implementation is still pending.
public void merge(@NonNull final RobustRidgeRegression other) {
// must match d, k (should be able to merge larger k into smaller?)
/**
 * Returns the contents of the internal weights store as a one-dimensional array copy.
 *
 * <p>NOTE(review): the three terms computed below look like the pieces of
 * <tt>V (S^2 + gamma I)^-1 V^T A^T y + (1/gamma) (I - V V^T) A^T y</tt>, but no statement visible
 * here combines them or writes them into the weights store -- confirm against the complete method.</p>
 *
 * @return a copy of the weights store as a raw one-dimensional array
 */
public double[] solve() {
// allocate the weights store on first use
if (weights_ == null) {
weights_ = Primitive64Store.FACTORY.make(1, d_);
// make sure any new data has contributed to V^T and singular values
// NOTE(review): sv_ and Vt_ are only allocated inside reduceRank(); no call that guarantees this
// is visible here, so sv_ may still be null at this point -- confirm.
// TODO: seems there should be a modifyDiagonal() to apply this?
// d_ x d_ sparse diagonal holding 1 / (sv^2 + gamma) for each stored singular value
final SparseStore<Double> invDiag = SparseStore.makePrimitive(d_, d_);
for (int i = 0; i < sv_.length; ++i) {
invDiag.set(i, i, 1.0/(sv_[i] * sv_[i] + gamma_));
// V * invDiag * (V^T * A^T y)
MatrixStore<Double> firstTerm = (Vt_.transpose().multiply(invDiag)).multiply(Vt_.multiply(ATyAccum_));
// (A^T y) / gamma
// NOTE(review): the constructor accepts gamma == 0, which makes 1.0 / gamma_ infinite unless
// reduceRank() has increased gamma_ beforehand.
MatrixStore<Double> secondTerm = ATyAccum_.multiply(1.0 / gamma_);
// (1 / gamma) * V * (V^T * A^T y)
MatrixStore<Double> thirdTerm = Vt_.transpose().multiply(1.0 / gamma_).multiply( Vt_.multiply(ATyAccum_) );
return weights_.toRawCopy1D();
// Shrinks the sketch: subtracts one squared singular value from the leading squared singular
// values, zeroes the trailing ones, adds half of the subtracted amount to gamma_, and writes
// S_ * Vt_ back into B_.
private void reduceRank() {
// nothing to reduce while fewer than k_ rows are in use
if (nextZeroRow_ < k_) { return; }
// first call: create the decomposition object and the buffers it fills
if (svd_ == null) {
svd_ = SingularValue.PRIMITIVE.make(B_);
sv_ = new double[l_];
S_ = SparseStore.makePrimitive(sv_.length, sv_.length);
// full SVD
// computes U and V matrices even though we only use the latter
// NOTE(review): the statements that run the decomposition and fill sv_ and Vt_ are not visible
// in this block -- confirm against the complete method.
// zero-out singular values and update gamma_
// NOTE(review): index k_ is the (k_ + 1)-th singular value, while the comment below says the
// (l_/2)th, i.e. k_-th, item -- confirm which index is intended.
double medianSVSq = sv_[k_]; // (l_/2)th item, not yet squared
medianSVSq *= medianSVSq;
// only half of the subtracted squared singular value is added to the regularization term
gamma_ += 0.5 * medianSVSq;
// shrink the leading k_ - 1 singular values: sqrt(sv^2 - medianSVSq), clamped at zero
for (int i = 0; i < (k_ - 1); ++i) {
final double val = sv_[i];
final double adjSqSV = (val * val) - medianSVSq;
S_.set(i, i, adjSqSV < 0 ? 0.0 : Math.sqrt(adjSqSV)); // just to be safe
// zero the trailing entries
// NOTE(review): diagonal entry k_ - 1 is written by neither loop; it keeps whatever value S_
// already holds there.
for (int i = k_; i < S_.countColumns(); ++i) {
S_.set(i, i, 0.0);
// store the result back in B_
S_.multiply(Vt_, B_);
// update bookkeeping now
nextZeroRow_ = (int) Math.min(k_ - 1, n_);
// Ad-hoc entry point: builds a 10-dimensional regression object with gamma = 1.0 and the default
// sketch size, then feeds it one all-zero vector with target 1.0.
public static void main(String[] args) {
RobustRidgeRegression rr = new RobustRidgeRegression(1.0, 10);
rr.update(new double[10], 1.0);