// blob: 5d325603188990573cdc338d292251a9f10fa1db [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.datasketches.vector.regression;
import static org.testng.Assert.assertEquals;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.apache.datasketches.vector.matrix.Matrix;
import org.testng.SkipException;
import org.testng.annotations.Test;
/**
 * Tests for {@link RidgeRegression}: two small hand-built matrices plus an optional run over
 * externally supplied tab-separated data files.
 */
public class RidgeRegressionTest {
  /** System property that overrides the directory containing the {@code years_*.out} files. */
  private static final String DATA_DIR_PROPERTY = "ridgeRegression.dataDir";

  /** Directory used when {@link #DATA_DIR_PROPERTY} is not set. */
  private static final String DEFAULT_DATA_DIR =
      "/Users/jmalkin/projects/FrequentDirectionsRidgeRegression/notebooks/SongPredictions/";

  /**
   * Builds a 5x2 matrix whose second column is ten times the first and checks its dimensions.
   */
  @Test
  public void normalize() {
    final int nRows = 5;
    final int nCols = 2;
    final Matrix m = Matrix.builder().build(nRows, nCols);
    for (int r = 0; r < nRows; ++r) {
      // column 0 holds 1..5, column 1 holds 10..50
      m.setElement(r, 0, r + 1);
      m.setElement(r, 1, 10 * (r + 1));
    }

    assertEquals(m.getNumRows(), nRows);
    assertEquals(m.getNumColumns(), nCols);

    // TODO: the regression call below is disabled; this test currently only builds the input.
    //   final double[] targets = new double[] {-1, 1, 0, -1, 0.5};
    //   RidgeRegression rr = new RidgeRegression(5, 1.0, true);
    //   rr.fit(m, targets);
  }

  /**
   * Fits a regression on a 5x2 matrix whose second column is all zeros and prints the resulting
   * weights and intercept. No numeric expectations are asserted.
   */
  @Test
  public void basicExactRegression() {
    final int nRows = 5;
    final int nCols = 2;
    final Matrix m = Matrix.builder().build(nRows, nCols);
    final double[] firstColumn = {2, 3, 5, 7, 9};
    for (int r = 0; r < nRows; ++r) {
      m.setElement(r, 0, firstColumn[r]);
    }
    m.setColumn(1, new double[] {0, 0, 0, 0, 0});

    final double[] targets = new double[] {4, 5, 7, 10, 15};

    // NOTE(review): the meaning of the constructor arguments and of fit()'s third argument is
    // defined by RidgeRegression; the values are kept exactly as in the original test.
    final RidgeRegression rr = new RidgeRegression(5, 0.0, false);
    rr.fit(m, targets, true);

    System.out.println("Weights:");
    for (int i = 0; i < nCols; ++i) {
      System.out.println("\t" + i + ":\t" + rr.getWeights()[i]);
    }
    System.out.println("Intercept: " + rr.getIntercept());
  }

  /**
   * Trains on {@code years_train.out} and prints the training, validation and test errors.
   *
   * <p>The data files are read from the directory named by the {@code ridgeRegression.dataDir}
   * system property, falling back to the original developer-local path. The test is skipped,
   * rather than failed, when any of the three files is missing.
   *
   * @throws SkipException if a required data file is not present
   */
  @Test
  public void YearDataTest() {
    final int nTrain = 16000;
    final int nValid = 4000;
    final int nTest = 5000;

    final String dataDir = System.getProperty(DATA_DIR_PROPERTY, DEFAULT_DATA_DIR);
    final File trainFile = new File(dataDir, "years_train.out");
    final File validFile = new File(dataDir, "years_valid.out");
    final File testFile = new File(dataDir, "years_test.out");
    requireDataFiles(trainFile, validFile, testFile);

    final Matrix fullTrain = loadTSVData(trainFile.getPath(), nTrain);
    final Matrix fullValid = loadTSVData(validFile.getPath(), nValid);
    final Matrix fullTest = loadTSVData(testFile.getPath(), nTest);

    // TestNG's assertEquals takes (actual, expected)
    final int d = (int) fullTrain.getNumColumns() - 1;
    assertEquals(fullValid.getNumColumns() - 1, d);
    assertEquals(fullTest.getNumColumns() - 1, d);
    assertEquals(fullTrain.getNumRows(), nTrain);
    assertEquals(fullValid.getNumRows(), nValid);
    assertEquals(fullTest.getNumRows(), nTest);

    // last column is targets
    final double[] yTrain = fullTrain.getColumn(d);
    final double[] yValid = fullValid.getColumn(d);
    final double[] yTest = fullTest.getColumn(d);

    // the remaining d columns are the features
    final Matrix xTrain = Matrix.builder().build(nTrain, d);
    final Matrix xValid = Matrix.builder().build(nValid, d);
    final Matrix xTest = Matrix.builder().build(nTest, d);
    for (int i = 0; i < d; ++i) {
      xTrain.setColumn(i, fullTrain.getColumn(i));
      xValid.setColumn(i, fullValid.getColumn(i));
      xTest.setColumn(i, fullTest.getColumn(i));
    }

    // NOTE(review): constructor and fit() arguments kept as in the original test; confirm their
    // meaning against RidgeRegression.
    final RidgeRegression rr = new RidgeRegression(256, 10000.0, false);
    double error = rr.fit(xTrain, yTrain, true);

    System.out.print("[");
    for (final double w : rr.getWeights()) {
      System.out.print(w + "\t");
    }
    System.out.println("]");
    System.out.println("Intercept: " + rr.getIntercept());

    // (needlessly) computed as part of fit
    System.out.println("Train error: " + error);

    double[] pred = rr.predict(xValid);
    error = rr.getError(pred, yValid);
    System.out.println("Validation error: " + error);

    pred = rr.predict(xTest);
    error = rr.getError(pred, yTest);
    System.out.println("Test error: " + error);
  }

  /**
   * Skips the calling test unless every given file exists as a regular file.
   *
   * @param files the data files the test needs
   * @throws SkipException naming the first file that is missing
   */
  private static void requireDataFiles(final File... files) {
    for (final File f : files) {
      if (!f.isFile()) {
        throw new SkipException("Data file not found, skipping test: " + f.getPath());
      }
    }
  }

  /**
   * Reads a file of tab-separated numeric values into a matrix, one file line per matrix row.
   *
   * <p>The column count is taken from the first line. Lines beyond {@code nRows} are counted but
   * not stored, so a file with too many lines is reported by the row-count assertion instead of
   * by an attempt to write outside the matrix.
   *
   * @param inputFile path of the file to read
   * @param nRows number of lines the file is expected to contain
   * @return the loaded matrix, or {@code null} if the file has no lines and {@code nRows} is 0
   * @throws UncheckedIOException if the file cannot be opened or read
   * @throws NumberFormatException if a stored field is not parseable as a double
   * @throws AssertionError if the number of lines differs from {@code nRows}
   */
  Matrix loadTSVData(final String inputFile, final int nRows) {
    Matrix data = null;
    int row = 0;
    try (BufferedReader br = new BufferedReader(new FileReader(inputFile))) {
      String line;
      while ((line = br.readLine()) != null) {
        if (row < nRows) {
          final String[] strValues = line.split("\t");
          final double[] values = new double[strValues.length];
          for (int d = 0; d < strValues.length; ++d) {
            values[d] = Double.parseDouble(strValues[d]);
          }
          if (data == null) {
            // size the matrix lazily: the column count is only known after the first line
            data = Matrix.builder().build(nRows, values.length);
          }
          data.setRow(row, values);
        }
        ++row;
      }
    } catch (final IOException e) {
      throw new UncheckedIOException("Unable to read " + inputFile, e);
    }
    assertEquals(row, nRows);
    return data;
  }
}