WIP: FD-based ridge regression
diff --git a/pom.xml b/pom.xml
index cd7c287..9cdf625 100644
--- a/pom.xml
+++ b/pom.xml
@@ -166,6 +166,11 @@
       <artifactId>datasketches-memory</artifactId>
       <version>${datasketches-memory.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.checkerframework</groupId>
+      <artifactId>checker-qual</artifactId>
+      <version>3.10.0</version>
+    </dependency>
     <!--
     <dependency>
       <groupId>org.apache.commons</groupId>
@@ -173,6 +178,11 @@
       <version>${commons-math3.version}</version>
     </dependency>
     -->
+    <dependency>
+      <groupId>com.google.errorprone</groupId>
+      <artifactId>javac</artifactId>
+      <version>9+181-r4173-1</version>
+    </dependency>
     <!-- END: UNIQUE FOR THIS JAVA COMPONENT -->
 
     <!-- Test Scope -->
diff --git a/src/main/java/org/apache/datasketches/vector/SketchesArgumentException.java b/src/main/java/org/apache/datasketches/vector/SketchesArgumentException.java
new file mode 100644
index 0000000..c1d33d9
--- /dev/null
+++ b/src/main/java/org/apache/datasketches/vector/SketchesArgumentException.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.vector;
+
+/**
+ * Exception thrown when an illegal or invalid argument is passed to a
+ * sketches-vector method.
+ */
+public class SketchesArgumentException extends RuntimeException {
+  private static final long serialVersionUID = 1L;
+
+  public SketchesArgumentException(final String message) {
+    super(message);
+  }
+
+  public SketchesArgumentException(final String message, final Throwable throwable) {
+    super(message, throwable);
+  }
+}
diff --git a/src/main/java/org/apache/datasketches/vector/regression/RidgeRegression.java b/src/main/java/org/apache/datasketches/vector/regression/RidgeRegression.java
new file mode 100644
index 0000000..6057be1
--- /dev/null
+++ b/src/main/java/org/apache/datasketches/vector/regression/RidgeRegression.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.vector.regression;
+
+import org.apache.datasketches.vector.SketchesArgumentException;
+import org.apache.datasketches.vector.decomposition.FrequentDirections;
+import org.apache.datasketches.vector.matrix.Matrix;
+import org.ojalgo.function.PrimitiveFunction;
+import org.ojalgo.function.aggregator.Aggregator;
+import org.ojalgo.function.constant.PrimitiveMath;
+import org.ojalgo.matrix.decomposition.SingularValue;
+import org.ojalgo.matrix.store.MatrixStore;
+import org.ojalgo.matrix.store.Primitive64Store;
+import org.ojalgo.matrix.store.SparseStore;
+
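+/**
+ * Ridge regression solved either exactly via SVD or approximately from a
+ * Frequent Directions sketch of the data. A minimal usage sketch (parameter
+ * values are illustrative only):
+ * <pre>
+ *   RidgeRegression rr = new RidgeRegression(64, 1.0, false); // k, gamma, useRobust
+ *   double trainRmse = rr.fit(data, targets, false);          // data: n x d Matrix
+ *   double[] predictions = rr.predict(newData);
+ * </pre>
+ */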
+public class RidgeRegression {
+
+  private final int k_;
+  private final double gamma_;
+  private final boolean useRobust_;
+
+  private Primitive64Store xOffset_;
+  private Primitive64Store xScale_;
+
+  private Primitive64Store weights_;
+  private double intercept_;
+
+  private long n_;
+  private int d_;
+
+
+  public RidgeRegression(final int k, final double gamma, final boolean useRobust) {
+    k_ = k;
+    gamma_ = gamma;
+    useRobust_ = useRobust;
+
+    n_ = 0;
+    d_ = 0;
+  }
+
+  public void fit(final Matrix data, final double[] targets) {
+    fit(data, targets, false);
+  }
+
+  /**
+   * Fits the ridge regression model to the given data and targets.
+   * @param data an n x d data Matrix, with one input vector per row. MODIFIES INPUT DATA
+   * @param targets an n-dimensional array of regression target values
+   * @param exact if true, computes the exact solution, otherwise a sketch-based approximation
+   * @return root-mean-square error on the training data
+   */
+  public double fit(final Matrix data, final double[] targets, final boolean exact) {
+    n_ = data.getNumRows();
+    d_ = (int) data.getNumColumns();
+
+    // preallocate the structures we'll use
+    xOffset_ = Primitive64Store.FACTORY.make(1, d_);
+    xScale_ = Primitive64Store.FACTORY.make(1, d_);
+    weights_ = Primitive64Store.FACTORY.make(d_, 1); // column vector, so data.multiply(weights_) conforms
+
+    preprocessData(data, targets, true, true);
+    solve(data, targets, exact);
+
+    double[] predictions = predict(data, true);
+    return getError(predictions, targets);
+  }
+
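+  /**
+   * Predicts target values for the rows of the given matrix, applying the
+   * normalization and intercept learned in fit().
+   * @param data an n x d Matrix of input vectors, one per row
+   * @return an n-dimensional array of predictions
+   */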
+  public double[] predict(final Matrix data) {
+    return predict(data, false);
+  }
+
+  private double[] predict(final Matrix data, final boolean preNormalized) {
+    if (data.getNumColumns() != d_)
+      throw new SketchesArgumentException("Input matrix for prediction must have " + d_
+          + " columns. Found: " + data.getNumColumns());
+
+    final Primitive64Store mtx = (Primitive64Store) data.getRawObject();
+    final MatrixStore<Double> rawPredictions;
+
+    if (preNormalized) {
+      rawPredictions = mtx.multiply(weights_);
+    } else {
+      Primitive64Store adjustedMtx = mtx.copy();
+      adjustedMtx.modifyMatchingInRows(PrimitiveMath.SUBTRACT, xOffset_);
+      adjustedMtx.modifyMatchingInRows(PrimitiveMath.DIVIDE, xScale_);
+      rawPredictions = adjustedMtx.multiply(weights_);
+    }
+
+    // MatrixStore.onAll() returns a new supplier rather than modifying in place,
+    // so apply the intercept to the raw copy directly
+    final double[] predictions = rawPredictions.toRawCopy1D();
+    for (int i = 0; i < predictions.length; ++i) {
+      predictions[i] += intercept_;
+    }
+    return predictions;
+  }
+
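+  /**
+   * Computes the root-mean-square error (RMSE) between predicted and true values.
+   * @param y_pred array of predicted values
+   * @param y_true array of true target values
+   * @return the RMSE of the predictions
+   */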
+  public double getError(final double[] y_pred, final double[] y_true) {
+    if (y_pred.length != y_true.length)
+      throw new SketchesArgumentException("Predictions and true value vectors differ in length: "
+          + y_pred.length + " != " + y_true.length);
+
+    double cumSqErr = 0.0;
+    for (int i = 0; i < y_pred.length; ++i) {
+      double val = y_pred[i] - y_true[i];
+      cumSqErr += val * val;
+    }
+
+    return Math.sqrt(cumSqErr / y_pred.length);
+  }
+
+  public double[] getWeights() {
+    return weights_.data.clone();
+  }
+
+  public double getIntercept() {
+    return intercept_;
+  }
+
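+  // Centers each column of the data (and the targets) and optionally rescales
+  // columns to unit standard deviation, so the solver can fit without an
+  // explicit intercept column.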
+  private void preprocessData(final Matrix data, final double[] targets, final boolean fitIntercept, final boolean normalize) {
+    Primitive64Store mtx = (Primitive64Store) data.getRawObject();
+
+    if (fitIntercept) {
+      mtx.reduceColumns(Aggregator.AVERAGE).supplyTo(xOffset_);
+      intercept_ = Primitive64Store.wrap(targets).aggregateAll(Aggregator.AVERAGE);
+
+      // subtract xOffset from input matrix, yOffset from targets
+      mtx.modifyMatchingInRows(PrimitiveMath.SUBTRACT, xOffset_);
+      for (int r = 0; r < n_; ++r) {
+        targets[r] -= intercept_;
+      }
+
+      if (normalize) {
+        mtx.reduceColumns(Aggregator.NORM2).supplyTo(xScale_);
+        // map any zeros to 1.0 and adjust from norm2 to stdev
+        PrimitiveFunction.Unary fixZero = arg -> arg == 0.0 ? 1.0 : Math.sqrt(arg * arg / n_);
+        xScale_.modifyAll(fixZero);
+        mtx.modifyMatchingInRows(PrimitiveMath.DIVIDE, xScale_);
+      } else {
+        xScale_.fillAll(1.0);
+      }
+    } else {
+      xOffset_.fillAll(0.0);
+      xScale_.fillAll(1.0);
+      intercept_ = 0.0;
+    }
+  }
+
+
+  private void solve(final Matrix data, final double[] targets, final boolean exact) {
+    final Primitive64Store sketchMtx;
+    final MatrixStore<Double> Vt;
+    final double[] sv;
+    final int nDim;
+    if (exact) {
+      nDim = d_;
+      sketchMtx = (Primitive64Store) data.getRawObject();
+      final SingularValue<Double> svd = SingularValue.PRIMITIVE.make(sketchMtx);
+      svd.decompose(sketchMtx);
+      sv = new double[nDim];
+      svd.getSingularValues(sv);
+      Vt = svd.getV().transpose();
+    } else {
+      final FrequentDirections fd = FrequentDirections.newInstance(k_, d_);
+      for (int r = 0; r < data.getNumRows(); ++r) {
+        fd.update(data.getRow(r));
+      }
+      fd.forceReduceRank();
+
+      sv = fd.getSingularValues(useRobust_);
+      Vt = (Primitive64Store) fd.getProjectionMatrix().getRawObject();
+      nDim = (int) Vt.countRows();
+    }
+
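+    // Closed-form ridge solution from the (possibly truncated) SVD A ~ U S V^T,
+    // with b = A^T y:
+    //   w = V (S^2 + gamma I)^{-1} V^T b  +  (1/gamma) (I - V V^T) b
+    // The last two terms handle the subspace not captured by V; with the full
+    // (exact) SVD, V V^T = I and they cancel.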
+    final MatrixStore<Double> ATy = ((Primitive64Store) data.getRawObject()).transpose().multiply(Primitive64Store.wrap(targets));
+    // TODO: seems there should be a modifyDiagonal() to apply this?
+    final SparseStore<Double> invDiag = SparseStore.makePrimitive(nDim, nDim);
+    for (int i = 0; i < sv.length; ++i) {
+      invDiag.set(i, i, 1.0/(sv[i] * sv[i] + gamma_));
+    }
+
+    MatrixStore<Double> firstTerm = (Vt.transpose().multiply(invDiag)).multiply(Vt.multiply(ATy));
+    MatrixStore<Double> secondTerm = ATy.multiply(1.0 / gamma_);
+    MatrixStore<Double> thirdTerm = Vt.transpose().multiply(1.0 / gamma_).multiply( Vt.multiply(ATy) );
+
+    firstTerm.add(secondTerm).subtract(thirdTerm).supplyTo(weights_);
+  }
+}
diff --git a/src/main/java/org/apache/datasketches/vector/regression/RobustRidgeRegression.java b/src/main/java/org/apache/datasketches/vector/regression/RobustRidgeRegression.java
new file mode 100644
index 0000000..5a6d96e
--- /dev/null
+++ b/src/main/java/org/apache/datasketches/vector/regression/RobustRidgeRegression.java
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.vector.regression;
+
+import org.apache.datasketches.vector.SketchesArgumentException;
+import org.ojalgo.array.Array1D;
+import org.ojalgo.function.constant.PrimitiveMath;
+import org.ojalgo.matrix.decomposition.SingularValue;
+import org.ojalgo.matrix.store.MatrixStore;
+import org.ojalgo.matrix.store.Primitive64Store;
+
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.ojalgo.matrix.store.SparseStore;
+
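+/**
+ * Streaming ridge regression over a Robust Frequent Directions sketch: rows and
+ * targets are consumed one at a time via update(), and solve() produces the
+ * current weight vector. A minimal usage sketch (values illustrative only):
+ * <pre>
+ *   RobustRidgeRegression rr = new RobustRidgeRegression(1.0, d); // gamma, d
+ *   for (int i = 0; i &lt; n; ++i) { rr.update(rows[i], targets[i]); }
+ *   double[] weights = rr.solve();
+ * </pre>
+ */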
+public class RobustRidgeRegression {
+  private final int d_;
+  private final int k_;
+  private final int l_; // convenience value so we don't need to compute 2*k frequently
+
+  private Primitive64Store xOffset_;
+  private Primitive64Store xScale_;
+
+  private final Primitive64Store B_;          // Sketch of the data
+  private final Primitive64Store ATyAccum_;   // Accumulates A^T y - a d-dim vector
+  private double intercept_;
+
+  private double gamma_;
+  private long n_;
+  private int nextZeroRow_;
+
+  private Primitive64Store weights_;
+
+  // transient values for SVD
+  private double[] sv_;
+  private Primitive64Store Vt_;
+  private SparseStore<Double> S_; // to hold singular value matrix
+  private SingularValue<Double> svd_;
+  //private Eigenvalue<Double> evd_;
+
+
+  /**
+   * Creates an object ready to accept data for Robust Frequent Directions Ridge Regression. The sketch size
+   * parameter <tt>k</tt> controls the accuracy/size trade-off, but must be no greater than <tt>d/2</tt>.
+   * @param gamma A nonnegative regularization coefficient
+   * @param d The number of dimensions in each input vector
+   * @param k The sketch size parameter
+   */
+  RobustRidgeRegression(final double gamma, final int d, final int k) {
+    if (gamma < 0.0)
+      throw new SketchesArgumentException("Gamma must be nonnegative. Found: " + gamma);
+    if (k < 1)
+      throw new SketchesArgumentException("k must be at least 1. Found: " + k);
+    if (d < 2 * k)
+      throw new SketchesArgumentException("d must be at least 2k. Found d=" + d + ", k=" + k);
+
+    gamma_ = gamma;
+    d_ = d;
+    k_ = k;
+    l_ = 2 * k;
+
+    n_ = 0;
+    nextZeroRow_ = 0;
+
+    B_ = Primitive64Store.FACTORY.make(l_, d_);
+    ATyAccum_ = Primitive64Store.FACTORY.make(d_, 1);
+  }
+
+  /**
+   * Creates an object ready to accept data for Robust Frequent Directions Ridge Regression. Uses a default
+   * sketch size of the maximum supported for a given value of <tt>d</tt>.
+   * @param gamma A nonnegative regularization coefficient
+   * @param d The number of dimensions in each input vector
+   */
+  RobustRidgeRegression(final double gamma, final int d) {
+    this(gamma, d, d / 2);
+  }
+
+  /**
+   * Initializes mean/variance normalization based on a VectorNormalizer object. The input normalizer
+   * must match the configured dimensionality of the regression object.
+   * @param normalizer A VectorNormalizer run on (a sample of) the data to be modeled.
+   */
+  public void setNormalization(@NonNull final VectorNormalizer normalizer) {
+    if (normalizer.getD() != d_)
+      throw new SketchesArgumentException("VectorNormalizer dimension must match configured dimensions. "
+          + normalizer.getD() + " != " + d_);
+
+    xOffset_ = Primitive64Store.wrap(normalizer.getMean());
+    xScale_ = Primitive64Store.wrap(normalizer.getSampleVariance()); // variance, not std. deviation
+    xScale_.modifyAll(PrimitiveMath.SQRT);
+    intercept_ = normalizer.getIntercept();
+  }
+
+  /**
+   * Initializes mean/variance normalization using specified arrays. Note that the second argument
+   * is an array of <em>variance</em> values, not standard deviations. Both arrays must match the
+   * configured dimensionality of the regression object.
+   * @param means An array of mean values
+   * @param variances An array of variance values
+   */
+  public void setNormalization(final double[] means, final double[] variances, final double intercept) {
+    if (means == null || variances == null)
+      throw new SketchesArgumentException("Mean and variance arrays cannot be null.");
+    if (means.length != d_ || variances.length != d_)
+      throw new SketchesArgumentException("Mean and variance arrays must be of length " + d_);
+
+    xOffset_ = Primitive64Store.wrap(means.clone());
+    xScale_ = Primitive64Store.wrap(variances.clone()); // variance, not std. deviation
+    xScale_.modifyAll(PrimitiveMath.SQRT);
+    intercept_ = intercept;
+  }
+
+  // add a single vector to the sketch
+  public void update(final double[] data, final double target) {
+    if (data == null || data.length != d_)
+      throw new SketchesArgumentException("data must be a non-null vector of length " + d_);
+
+    update(Primitive64Store.wrap(data, 1), Primitive64Store.wrap(target));
+  }
+
+  // add multiple vectors to the sketch
+  public void update(@NonNull final Primitive64Store data, @NonNull final Primitive64Store targets) {
+    if (data.countColumns() != d_)
+      throw new SketchesArgumentException("data must have " + d_ + " columns. Found: " + data.countColumns());
+    if (data.countRows() != targets.count())
+      throw new SketchesArgumentException("number of rows in data (" + data.countRows() + ") does not match"
+          + " number of targete (" + targets.count() + ")");
+
+    // append rows to B_ until we have l_ of them, then reduce and adjust gamma
+    for (int i = 0; i < data.countRows(); ++i) {
+      Array1D<Double> row = data.sliceRow(i);
+
+      if (nextZeroRow_ == l_) {
+        reduceRank();
+      }
+
+      // accumulate A^T y and copy the row into B_, applying normalization if supplied
+      // TODO: may be able to improve performance by normalizing data first?
+      if (xOffset_ != null) {
+        for (int j = 0; j < d_; ++j) {
+          // use the normalized feature for both the accumulator and the sketch,
+          // and center the target by the intercept, matching RidgeRegression
+          final double normalized = (row.get(j) - xOffset_.get(j)) / xScale_.get(j);
+          ATyAccum_.set(j, ATyAccum_.get(j) + (normalized * (targets.get(i) - intercept_)));
+          B_.set(nextZeroRow_, j, normalized);
+        }
+      } else {
+        for (int j = 0; j < d_; ++j) {
+          ATyAccum_.set(j, ATyAccum_.get(j) + (row.get(j) * targets.get(i)));
+          B_.set(nextZeroRow_, j, row.get(j));
+        }
+      }
+
+      ++n_;
+      ++nextZeroRow_;
+    }
+  }
+
+  public void merge(@NonNull final RobustRidgeRegression other) {
+    // TODO: implement merge; must match d and k (should be able to merge larger k into smaller?)
+    throw new UnsupportedOperationException("merge() is not yet implemented");
+  }
+
+  public double[] solve() {
+    if (n_ == 0)
+      throw new SketchesArgumentException("Cannot solve() before any data has been added");
+
+    if (weights_ == null) {
+      weights_ = Primitive64Store.FACTORY.make(d_, 1); // column vector
+    }
+
+    // make sure any new data has contributed to V^T and singular values
+    reduceRank();
+
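+    // Same closed form as RidgeRegression.solve(), with b = ATyAccum_:
+    //   w = V (S^2 + gamma I)^{-1} V^T b + (1/gamma) (I - V V^T) b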
+    // TODO: seems there should be a modifyDiagonal() to apply this?
+    // invDiag must match Vt_'s row count (the number of retained singular values)
+    final SparseStore<Double> invDiag = SparseStore.makePrimitive(sv_.length, sv_.length);
+    for (int i = 0; i < sv_.length; ++i) {
+      invDiag.set(i, i, 1.0/(sv_[i] * sv_[i] + gamma_));
+    }
+
+    MatrixStore<Double> firstTerm = (Vt_.transpose().multiply(invDiag)).multiply(Vt_.multiply(ATyAccum_));
+    MatrixStore<Double> secondTerm = ATyAccum_.multiply(1.0 / gamma_);
+    MatrixStore<Double> thirdTerm = Vt_.transpose().multiply(1.0 / gamma_).multiply( Vt_.multiply(ATyAccum_) );
+
+    firstTerm.add(secondTerm).subtract(thirdTerm).supplyTo(weights_);
+
+    return weights_.toRawCopy1D();
+  }
+
+  private void reduceRank() {
+    if (nextZeroRow_ < k_) { return; }
+
+    if (svd_ == null) {
+      svd_ = SingularValue.PRIMITIVE.make(B_);
+      sv_ = new double[l_];
+      Vt_ = Primitive64Store.FACTORY.make(l_, d_); // V^T buffer for the economy-size SVD of the l_ x d_ sketch
+      S_ = SparseStore.makePrimitive(sv_.length, sv_.length);
+    }
+
+    // full SVD
+    // computes U and V matrices even though we only use the latter
+    svd_.decompose(B_);
+    svd_.getV().transpose().supplyTo(Vt_);
+    svd_.getSingularValues(sv_);
+
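+    // Robust FD: rather than subtracting the median squared singular value and
+    // keeping the result in the sketch (plain FD), shift half of the discarded
+    // mass into gamma_ so it acts as additional ridge regularization.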
+    // zero-out singular values and update gamma_
+    double medianSVSq = sv_[k_]; // (l_/2)th item, not yet squared
+    medianSVSq *= medianSVSq;
+    gamma_ += 0.5 * medianSVSq;
+    for (int i = 0; i < (k_ - 1); ++i) {
+      final double val = sv_[i];
+      final double adjSqSV = (val * val) - medianSVSq;
+      S_.set(i, i, adjSqSV < 0 ? 0.0 : Math.sqrt(adjSqSV)); // just to be safe
+    }
+    for (int i = k_ - 1; i < S_.countColumns(); ++i) { // start at k_ - 1 so its entry is cleared, too
+      S_.set(i, i, 0.0);
+    }
+
+    // store the result back in B_
+    S_.multiply(Vt_, B_);
+
+
+    // update bookkeeping now
+    nextZeroRow_ = (int) Math.min(k_ - 1, n_);
+  }
+
+  public static void main(String[] args) {
+    RobustRidgeRegression rr = new RobustRidgeRegression(1.0, 10);
+    rr.update(new double[10], 1.0);
+    rr.solve();
+  }
+}
diff --git a/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java b/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java
index 49a5eec..0047e00 100644
--- a/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java
+++ b/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java
@@ -24,6 +24,7 @@
 import org.apache.datasketches.memory.Memory;
 import org.apache.datasketches.memory.WritableMemory;
 import org.apache.datasketches.vector.MatrixFamily;
+import org.checkerframework.checker.nullness.qual.NonNull;
 
 /**
  * Computes mean and variance for each of d dimensions of an input vector using Welford's online algorithm,
@@ -43,6 +44,12 @@
  *
  *      ||       8        |   9    |   10   |   11   |   12  |   13   |   14   |  15   |
  *  1   ||-------------------------Num. Vectors Processed (n)--------------------------|
+ *
+ *      ||       16       |   17   |   18   |   19   |   20  |   21   |   22   |  23   |
+ *  2   ||---------------------------Intercept (target mean)---------------------------|
+ *
+ *      ||       24       |   25   |   26   |   27   |   28  |   29   |   30   |  31   |
+ *  3   ||-----------------------------start of mean array-----------------------------|
  * </pre>
  *
  * @author Jon Malkin
@@ -52,10 +59,11 @@
   private final int d_;
   private final double[] mean_;
   private final double[] M2_;
+  private double intercept_;
   private long n_;
 
   // Preamble byte Addresses
-  static final int PREAMBLE_LONGS_BYTE = 0;
+  static final int PREAMBLE_LONGS_BYTE   = 0;
   static final int SER_VER_BYTE          = 1;
   static final int FAMILY_BYTE           = 2;
   static final int FLAGS_BYTE            = 3;
@@ -78,9 +86,10 @@
       throw new IllegalArgumentException("d cannot be < 1. Found: " + d);
 
     d_ = d;
+    n_ = 0;
+    intercept_ = 0.0;
     mean_ = new double[d_];
     M2_ = new double[d_];
-    n_ = 0;
   }
 
   /**
@@ -90,13 +99,15 @@
   public VectorNormalizer(final VectorNormalizer other) {
     d_ = other.d_;
     n_ = other.n_;
+    intercept_ = other.intercept_;
     mean_ = other.mean_.clone();
     M2_ = other.M2_.clone();
   }
 
-  private VectorNormalizer(final int d, final long n, final double[] mean, final double[] M2) {
+  private VectorNormalizer(final int d, final long n, final double[] mean, final double[] M2, final double intercept) {
     d_ = d;
     n_ = n;
+    intercept_ = intercept;
     mean_ = mean;
     M2_ = M2;
   }
@@ -149,13 +160,16 @@
     long offsetBytes = (long) preLongs * Long.BYTES;
 
     // check capacity for the rest
-    final long bytesNeeded = offsetBytes + (2L * d * Double.BYTES);
+    final long bytesNeeded = offsetBytes + (((2L * d) + 1) * Double.BYTES);
     if (srcMem.getCapacity() < bytesNeeded) {
       throw new IllegalArgumentException(
           "Possible Corruption: Size of Memory not large enough: Size: " + srcMem.getCapacity()
               + ", Required: " + bytesNeeded);
     }
 
+    final double intercept = srcMem.getDouble(offsetBytes);
+    offsetBytes += Double.BYTES;
+
     final double[] mean = new double[d];
     srcMem.getDoubleArray(offsetBytes, mean, 0, d);
     offsetBytes += (long) d * Double.BYTES;
@@ -163,7 +177,7 @@
     final double[] M2 = new double[d];
     srcMem.getDoubleArray(offsetBytes, M2, 0, d);
 
-    return new VectorNormalizer(d, n, mean, M2);
+    return new VectorNormalizer(d, n, mean, M2, intercept);
   }
 
   /**
@@ -178,7 +192,7 @@
         ? MatrixFamily.VECTORNORMALIZER.getMinPreLongs()
         : MatrixFamily.VECTORNORMALIZER.getMaxPreLongs();
 
-    final int outBytes = (preLongs * Long.BYTES) + (empty ? 0 : 2 * d_ * Double.BYTES);
+    final int outBytes = (preLongs * Long.BYTES) + (empty ? 0 : (1 + 2 * d_) * Double.BYTES);
     final byte[] outArr = new byte[outBytes];
     final WritableMemory memOut = WritableMemory.wrap(outArr);
     final Object memObj = memOut.getArray();
@@ -193,6 +207,8 @@
     if (!empty) {
       insertN(memObj, memAddr, n_);
       long offset = (long) preLongs * Long.BYTES;
+      memOut.putDouble(offset, intercept_);
+      offset += Double.BYTES;
       memOut.putDoubleArray(offset, mean_, 0, d_);
       offset += (long) d_ * Double.BYTES;
       memOut.putDoubleArray(offset, M2_, 0, d_);
@@ -242,6 +258,17 @@
   }
 
   /**
+   * Returns the mean of the target value, aka the intercept
+   * @return Mean of the target value
+   */
+  public double getIntercept() {
+    if (n_ == 0)
+      return Double.NaN;
+    else
+      return intercept_;
+  }
+
+  /**
    * Returns the sample variance array represented in this object. Returns an array of NaN if N = 0 and an
    * array of zeros if N = 1.
    * @return The sample variance array represented in this object
@@ -287,7 +314,7 @@
     }
   }
 
-  public void update(double[] x) {
+  public void update(final double[] x, final double target) {
     if (x == null)
       return;
 
@@ -302,17 +329,18 @@
       double d2 = x[i] - mean_[i];  // x_i - newMean_i
       M2_[i] += d1 * d2;
     }
+
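+    // Welford-style running mean of the targets, stored as the intercept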
+    double delta = target - intercept_;
+    intercept_ += delta / n_;
   }
 
-  public void merge(VectorNormalizer other) {
-    if (other == null)
-      return;
-
+  public void merge(@NonNull final VectorNormalizer other) {
     if (other.d_ != d_)
       throw new IllegalArgumentException("Input VectorNormalizer must have d= " + d_ + ". Found: " + other.d_);
 
     long combinedN = n_ + other.n_;
     double varCountScalar = (n_ * other.n_) / (double) combinedN; // n_A * n_B / (n_A + n_B)
+    intercept_ = ((n_ * intercept_) + (other.n_ * other.intercept_)) / combinedN;
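+    // Chan et al. parallel-merge update for the per-dimension mean and M2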
     for (int i = 0; i < d_; ++i) {
       double meanDiff = other.mean_[i] - mean_[i];
       mean_[i] = ((n_ * mean_[i]) + (other.n_ * other.mean_[i])) / combinedN;
@@ -325,7 +353,7 @@
     if (n_ == 0) {
       return MatrixFamily.VECTORNORMALIZER.getMinPreLongs() * Long.BYTES;
     } else {
-      return (MatrixFamily.VECTORNORMALIZER.getMaxPreLongs()) * Long.BYTES + (2 * d_ * Double.BYTES);
+      return (MatrixFamily.VECTORNORMALIZER.getMaxPreLongs()) * Long.BYTES + ((1 + 2 * d_) * Double.BYTES);
     }
   }
 
diff --git a/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java b/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java
index 94da914..5d32560 100644
--- a/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java
+++ b/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java
@@ -1,2 +1,159 @@
-package org.apache.datasketches.vector.regression;public class RidgeRegressionTest {
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.vector.regression;
+
+import static org.testng.Assert.assertEquals;
+
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+
+import org.apache.datasketches.vector.matrix.Matrix;
+import org.testng.annotations.Test;
+
+public class RidgeRegressionTest {
+
+  @Test
+  public void normalize() {
+    final int nRows = 5;
+    final int nCols = 2;
+    final Matrix m = Matrix.builder().build(nRows, nCols);
+    m.setElement(0, 0, 1);
+    m.setElement(1, 0, 2);
+    m.setElement(2, 0, 3);
+    m.setElement(3, 0, 4);
+    m.setElement(4, 0, 5);
+    m.setElement(0, 1, 10);
+    m.setElement(1, 1, 20);
+    m.setElement(2, 1, 30);
+    m.setElement(3, 1, 40);
+    m.setElement(4, 1, 50);
+
+    final double[] targets = new double[] {-1, 1, 0, -1, 0.5};
+
+    //RidgeRegression rr = new RidgeRegression(5, 1.0, true);
+    //rr.fit(m, targets);
+  }
+
+  @Test
+  public void basicExactRegression() {
+    final int nRows = 5;
+    final int nCols = 2;
+    final Matrix m = Matrix.builder().build(nRows, nCols);
+    m.setElement(0, 0, 2);
+    m.setElement(1, 0, 3);
+    m.setElement(2, 0, 5);
+    m.setElement(3, 0, 7);
+    m.setElement(4, 0, 9);
+    m.setColumn(1, new double[]{0,0,0,0,0});
+    final double[] targets = new double[] {4, 5, 7, 10, 15};
+
+    RidgeRegression rr = new RidgeRegression(5, 0.0, false);
+    rr.fit(m, targets, true);
+    System.out.println("Weights:");
+    for (int i = 0; i < nCols; ++i) {
+      System.out.println("\t" + i + ":\t" + rr.getWeights()[i]);
+    }
+    //System.out.println("Slope: " + rr.getWeights()[0]);
+    System.out.println("Intercept: " + rr.getIntercept());
+  }
+
+  @Test(enabled = false) // depends on local data files; run manually
+  public void yearDataTest() {
+    final int nTrain = 16000;
+    final int nValid = 4000;
+    final int nTest = 5000;
+    final String path = "/Users/jmalkin/projects/FrequentDirectionsRidgeRegression/notebooks/SongPredictions/";
+    //Matrix fullTrain = loadTSVData(path + "years_train.tsv", nTrain);
+    //Matrix fullTest = loadTSVData(path + "years_test.tsv", nTest);
+    Matrix fullTrain = loadTSVData(path + "years_train.out", nTrain);
+    Matrix fullValid = loadTSVData(path + "years_valid.out", nValid);
+    Matrix fullTest = loadTSVData(path + "years_test.out", nTest);
+
+    final int d = (int) fullTrain.getNumColumns() - 1;
+    assertEquals(d, fullTest.getNumColumns() - 1);
+    assertEquals(nTrain, fullTrain.getNumRows());
+    assertEquals(nValid, fullValid.getNumRows());
+    assertEquals(nTest, fullTest.getNumRows());
+
+    // last column is targets
+    double[] yTrain = fullTrain.getColumn(d);
+    double[] yValid = fullValid.getColumn(d);
+    double[] yTest = fullTest.getColumn(d);
+
+    // grab the rest as training sets
+    Matrix xTrain = Matrix.builder().build(nTrain, d);
+    Matrix xValid = Matrix.builder().build(nValid, d);
+    Matrix xTest = Matrix.builder().build(nTest, d);
+    for (int i = 0; i < d; ++i) {
+      xTrain.setColumn(i, fullTrain.getColumn(i));
+      xValid.setColumn(i, fullValid.getColumn(i));
+      xTest.setColumn(i, fullTest.getColumn(i));
+    }
+
+    RidgeRegression rr = new RidgeRegression(256, 10000.0, false);
+    double error = rr.fit(xTrain, yTrain, true);
+    System.out.print("[");
+    for (final double w : rr.getWeights()) {
+      System.out.print(w + "\t");
+    }
+    System.out.println("]");
+    System.out.println("Intercept: " + rr.getIntercept());
+
+    // (needlessly) computed as part of fit
+    System.out.println("Train error: " + error);
+
+    double[] pred = rr.predict(xValid);
+    error = rr.getError(pred, yValid);
+    System.out.println("Validation error: " + error);
+
+    pred = rr.predict(xTest);
+    error = rr.getError(pred, yTest);
+    System.out.println("Test error: " + error);
+  }
+
+  Matrix loadTSVData(final String inputFile, final int nRows) {
+    Matrix data = null;
+    int row = 0;
+    try (BufferedReader br = new BufferedReader(new FileReader(inputFile))) {
+      String line;
+      while ((line = br.readLine()) != null) {
+        String[] strValues = line.split("\t");
+        double[] values = new double[strValues.length];
+
+        for (int d = 0; d < strValues.length; ++d)
+          values[d] = Double.parseDouble(strValues[d]);
+
+        if (data == null) {
+          data = Matrix.builder().build(nRows, values.length);
+        }
+        data.setRow(row, values);
+        ++row;
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+
+    assertEquals(row, nRows);
+    return data;
+  }
+
+
 }
diff --git a/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java b/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java
index e6a7807..6c55aa1 100644
--- a/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java
+++ b/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java
@@ -34,7 +34,6 @@
 import static org.testng.Assert.assertTrue;
 import static org.testng.Assert.fail;
 
-import javax.imageio.plugins.jpeg.JPEGImageWriteParam;
 import java.util.concurrent.ThreadLocalRandom;
 
 import org.apache.datasketches.memory.Memory;
@@ -88,7 +87,8 @@
     final VectorNormalizer vn = new VectorNormalizer(d);
 
     final double[] input = {-1, 0, 0.5};
-    vn.update(input);
+    final double target = 0.3;
+    vn.update(input, target);
     assertEquals(vn.getN(), 1);
     assertFalse(vn.isEmpty());
 
@@ -107,6 +107,8 @@
       assertEquals(sampleVar[i], 0.0);
       assertEquals(popVar[i], 0.0);
     }
+    // intercept should equal target
+    assertEquals(vn.getIntercept(), target);
   }
 
   @Test
@@ -123,7 +125,8 @@
       input[0] = rand.nextGaussian();      // mean = 0.0, var = 1.0
       input[1] = rand.nextDouble() * 2.0;  // mean = 1.0, var = (2-0)^2/12 = 1/3
       input[2] = rand.nextDouble() - 0.5;  // mean = 0.0, var = (1-0)^2/12
-      vn.update(input);
+      double target = rand.nextGaussian() - 1.0; // mean = -1.0
+      vn.update(input, target);
     }
     assertFalse(vn.isEmpty());
 
@@ -132,6 +135,7 @@
     assertEquals(mean[0], 0.0, tol);
     assertEquals(mean[1], 1.0, tol);
     assertEquals(mean[2], 0.0, tol);
+    assertEquals(vn.getIntercept(), -1.0, tol);
 
     // n is large enough that sample vs population variance won't matter for testing
     final double[] sampleVar = vn.getSampleVariance();
@@ -165,15 +169,18 @@
     // data expectations:
     // dimension 0: zero-mean, unit-variance Gaussian, even after merging
     // dimension 1: U[0,2] + U[2,4) -> U[0,4), so mean = 2.0 and var = 4^2/12=4/3
+    // target: a mix of N(-1,1) and N(1,1), so mean = 0.0 (variance not checked here)
     final double[] input = new double[d];
     for (int i = 0; i < n; ++i) {
       input[0] = rand.nextGaussian();
       input[1] = (rand.nextDouble() * 2.0) + 2.0;
-      vn1.update(input);
+      double target = rand.nextGaussian() - 1.0;
+      vn1.update(input, target);
 
       input[0] = rand.nextGaussian();
       input[1] = rand.nextDouble() * 2.0;
-      vn2.update(input);
+      target = rand.nextGaussian() + 1.0;
+      vn2.update(input, target);
     }
 
     vn1.merge(vn2);
@@ -182,6 +189,7 @@
     final double[] mean = vn1.getMean();
     assertEquals(mean[0], 0.0, tol);
     assertEquals(mean[1], 2.0, tol);
+    assertEquals(vn1.getIntercept(), 0.0, tol);
 
     // n is large enough that sample vs population variance won't matter for testing
     final double[] sampleVar = vn1.getSampleVariance();
@@ -200,15 +208,12 @@
 
     final double[] input = new double[d];
     for (int i = 0; i < d; ++i) { input[i] = 1.0 * i; }
-    vn.update(input);
-    assertEquals(vn.getN(), 1);
-
-    vn.update(null);
+    vn.update(input, 0.0);
     assertEquals(vn.getN(), 1);
 
     try {
       final double[] badInput = {1.0};
-      vn.update(badInput);
+      vn.update(badInput, 0.0);
       fail();
     } catch (IllegalArgumentException e) {
       // expected
@@ -223,10 +228,7 @@
 
     double[] input = new double[d];
     for (int i = 0; i < d; ++i) { input[i] = 1.0 * i; }
-    vn1.update(input);
-    assertEquals(vn1.getN(), 1);
-
-    vn1.merge(null);
+    vn1.update(input, 1.0);
     assertEquals(vn1.getN(), 1);
 
     // update with a non-empty VN with a different value of d
@@ -234,7 +236,7 @@
     final VectorNormalizer vn2 = new VectorNormalizer(d2);
     input = new double[d2];
     for (int i = 0; i < d2; ++i) { input[i] = 1.0 * i; }
-    vn2.update(input);
+    vn2.update(input, 2.0);
     assertEquals(vn2.getN(), 1);
 
     try {
@@ -258,7 +260,7 @@
       for (int j = 0; j < d; ++j) {
         input[j] = rand.nextDouble();
       }
-      vn.update(input);
+      vn.update(input, rand.nextGaussian());
     }
 
     final VectorNormalizer vnCopy = new VectorNormalizer(vn);
@@ -294,7 +296,7 @@
       for (int j = 0; j < d; ++j) {
         input[j] = rand.nextGaussian();
       }
-      vn.update(input);
+      vn.update(input, rand.nextDouble());
     }
 
     outBytes = vn.toByteArray();
@@ -304,6 +306,7 @@
     assertFalse(rebuilt.isEmpty());
     assertEquals(vn.getD(), rebuilt.getD());
     assertEquals(vn.getN(), rebuilt.getN());
+    assertEquals(vn.getIntercept(), rebuilt.getIntercept());
 
     final double[] originalMean = vn.getMean();
-    final double[] rebuiltMean = vn.getMean();
+    final double[] rebuiltMean = rebuilt.getMean();
@@ -413,7 +416,7 @@
       for (int j = 0; j < d; ++j) {
         input[j] = rand.nextDouble();
       }
-      vn.update(input);
+      vn.update(input, rand.nextDouble());
     }
     assertFalse(vn.isEmpty());