blob: 5f3dbba07e988f6d9ea1380183dc505bb594f249 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.fitting.leastsquares;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.ArrayList;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
import org.apache.commons.math3.util.MathArrays;
/**
* This class gives access to the statistical reference datasets provided by the
* NIST (available
* <a href="http://www.itl.nist.gov/div898/strd/general/dataarchive.html">here</a>).
* Instances of this class can be created by invocation of the
* {@link StatisticalReferenceDatasetFactory}.
*/
public abstract class StatisticalReferenceDataset {

    /** Regular expression separating the fields of a line of a StRD file. */
    private static final String FIELD_SEPARATOR = "\\s+";

    /** Marker preceding the name of the dataset. */
    private static final String NAME_MARKER = "Dataset Name:";

    /** Marker of the line holding the certified residual sum of squares. */
    private static final String RSS_MARKER = "Residual Sum of Squares";

    /** The name of this dataset. */
    private final String name;

    /** The total number of observations (data points). */
    private final int numObservations;

    /** The total number of parameters. */
    private final int numParameters;

    /** The total number of starting points for the optimizations. */
    private final int numStartingPoints;

    /** The values of the predictor. */
    private final double[] x;

    /** The values of the response. */
    private final double[] y;

    /**
     * The starting values. {@code startingValues[j][i]} is the value of the
     * {@code i}-th parameter in the {@code j}-th set of starting values.
     */
    private final double[][] startingValues;

    /** The certified values of the parameters. */
    private final double[] a;

    /** The certified values of the standard deviation of the parameters. */
    private final double[] sigA;

    /** The certified value of the residual sum of squares. */
    private final double residualSumOfSquares;

    /** The least-squares problem. */
    private final LeastSquaresProblem problem;

    /**
     * Creates a new instance of this class from the specified data file. The
     * file must follow the StRD format.
     *
     * @param in the data file (read to the end, but not closed)
     * @throws IOException if an I/O error occurs, or if the file is missing
     * the starting values, the certified residual sum of squares or the
     * dataset name, or if a data or starting-values line is malformed
     */
    public StatisticalReferenceDataset(final BufferedReader in)
        throws IOException {
        final ArrayList<String> lines = new ArrayList<String>();
        for (String line = in.readLine(); line != null; line = in.readLine()) {
            lines.add(line);
        }

        int[] index = findLineNumbers("Data", lines);
        if (index == null) {
            throw new AssertionError("could not find line indices for data");
        }
        this.numObservations = index[1] - index[0] + 1;
        this.x = new double[this.numObservations];
        this.y = new double[this.numObservations];
        for (int i = 0; i < this.numObservations; i++) {
            // StRD line numbers are 1-based, whereas the list is 0-based.
            final String line = lines.get(index[0] + i - 1);
            final String[] tokens = tokenize(line);
            if (tokens.length < 2) {
                throw new IOException("malformed data line: " + line);
            }
            // Data columns are in reverse order: response first, then predictor.
            this.y[i] = Double.parseDouble(tokens[0]);
            this.x[i] = Double.parseDouble(tokens[1]);
        }

        index = findLineNumbers("Starting Values", lines);
        if (index == null) {
            throw new AssertionError(
                "could not find line indices for starting values");
        }
        this.numParameters = index[1] - index[0] + 1;
        double[][] start = null;
        this.a = new double[numParameters];
        this.sigA = new double[numParameters];
        for (int i = 0; i < numParameters; i++) {
            final String line = lines.get(index[0] + i - 1);
            // Expected layout: <name> = <start_1> ... <start_k> <certified> <sigma>
            final String[] tokens = tokenize(line);
            if (tokens.length < 4) {
                throw new IOException("malformed starting values line: " + line);
            }
            final int numStart = tokens.length - 4;
            if (start == null) {
                start = new double[numStart][numParameters];
            } else if (numStart != start.length) {
                // Every parameter must provide the same number of starting values.
                throw new IOException(
                    "inconsistent number of starting values in line: " + line);
            }
            for (int j = 0; j < numStart; j++) {
                start[j][i] = Double.parseDouble(tokens[j + 2]);
            }
            this.a[i] = Double.parseDouble(tokens[tokens.length - 2]);
            this.sigA[i] = Double.parseDouble(tokens[tokens.length - 1]);
        }
        if (start == null) {
            throw new IOException("could not find starting values");
        }
        this.numStartingPoints = start.length;
        this.startingValues = start;

        this.residualSumOfSquares = findResidualSumOfSquares(lines);
        this.name = findName(lines);
        this.problem = new LeastSquaresProblem();
    }

    /**
     * Exposes the model of the enclosing dataset as vector and matrix
     * functions of the parameters. This is an inner (non-static) class
     * because it evaluates the model through the enclosing instance.
     */
    class LeastSquaresProblem {

        /**
         * Returns the model function: for a parameter vector {@code a}, it
         * computes the model value at each predictor value of the dataset.
         *
         * @return the model function
         */
        public MultivariateVectorFunction getModelFunction() {
            return new MultivariateVectorFunction() {
                public double[] value(final double[] a) {
                    final int n = getNumObservations();
                    final double[] yhat = new double[n];
                    for (int i = 0; i < n; i++) {
                        yhat[i] = getModelValue(getX(i), a);
                    }
                    return yhat;
                }
            };
        }

        /**
         * Returns the Jacobian of the model function: row {@code i} holds the
         * partial derivatives of the model at the {@code i}-th predictor value.
         *
         * @return the Jacobian of the model function
         */
        public MultivariateMatrixFunction getModelFunctionJacobian() {
            return new MultivariateMatrixFunction() {
                public double[][] value(final double[] a)
                    throws IllegalArgumentException {
                    final int n = getNumObservations();
                    final double[][] j = new double[n][];
                    for (int i = 0; i < n; i++) {
                        j[i] = getModelDerivatives(getX(i), a);
                    }
                    return j;
                }
            };
        }
    }

    /**
     * Returns the name of this dataset.
     *
     * @return the name of the dataset
     */
    public String getName() {
        return name;
    }

    /**
     * Returns the total number of observations (data points).
     *
     * @return the number of observations
     */
    public int getNumObservations() {
        return numObservations;
    }

    /**
     * Returns a copy of the data arrays. The data is laid out as follows:
     * <ul>
     * <li>{@code data[0][i] = x[i]},</li>
     * <li>{@code data[1][i] = y[i]}.</li>
     * </ul>
     *
     * @return the array of data points.
     */
    public double[][] getData() {
        return new double[][] {
            MathArrays.copyOf(x), MathArrays.copyOf(y)
        };
    }

    /**
     * Returns the x-value of the {@code i}-th data point.
     *
     * @param i the index of the data point
     * @return the x-value
     */
    public double getX(final int i) {
        return x[i];
    }

    /**
     * Returns the y-value of the {@code i}-th data point.
     *
     * @param i the index of the data point
     * @return the y-value
     */
    public double getY(final int i) {
        return y[i];
    }

    /**
     * Returns the total number of parameters.
     *
     * @return the number of parameters
     */
    public int getNumParameters() {
        return numParameters;
    }

    /**
     * Returns the certified values of the parameters.
     *
     * @return a copy of the values of the parameters
     */
    public double[] getParameters() {
        return MathArrays.copyOf(a);
    }

    /**
     * Returns the certified value of the {@code i}-th parameter.
     *
     * @param i the index of the parameter
     * @return the value of the parameter
     */
    public double getParameter(final int i) {
        return a[i];
    }

    /**
     * Returns the certified values of the standard deviations of the
     * parameters.
     *
     * @return a copy of the standard deviations of the parameters
     */
    public double[] getParametersStandardDeviations() {
        return MathArrays.copyOf(sigA);
    }

    /**
     * Returns the certified value of the standard deviation of the {@code i}-th
     * parameter.
     *
     * @param i the index of the parameter
     * @return the standard deviation of the parameter
     */
    public double getParameterStandardDeviation(final int i) {
        return sigA[i];
    }

    /**
     * Returns the certified value of the residual sum of squares.
     *
     * @return the residual sum of squares
     */
    public double getResidualSumOfSquares() {
        return residualSumOfSquares;
    }

    /**
     * Returns the total number of starting points (initial guesses for the
     * optimization process).
     *
     * @return the number of starting points
     */
    public int getNumStartingPoints() {
        return numStartingPoints;
    }

    /**
     * Returns the {@code i}-th set of initial values of the parameters.
     *
     * @param i the index of the starting point
     * @return a copy of the starting point
     */
    public double[] getStartingPoint(final int i) {
        return MathArrays.copyOf(startingValues[i]);
    }

    /**
     * Returns the least-squares problem corresponding to fitting the model to
     * the specified data.
     *
     * @return the least-squares problem
     */
    public LeastSquaresProblem getLeastSquaresProblem() {
        return problem;
    }

    /**
     * Returns the value of the model for the specified values of the predictor
     * variable and the parameters.
     *
     * @param x the predictor variable
     * @param a the parameters
     * @return the value of the model
     */
    public abstract double getModelValue(final double x, final double[] a);

    /**
     * Returns the values of the partial derivatives of the model with respect
     * to the parameters.
     *
     * @param x the predictor variable
     * @param a the parameters
     * @return the partial derivatives
     */
    public abstract double[] getModelDerivatives(final double x,
                                                 final double[] a);

    /**
     * Splits a line of a StRD file into its whitespace-separated fields,
     * ignoring leading and trailing whitespace.
     *
     * @param line the line to split
     * @return the fields of the line
     */
    private static String[] tokenize(final String line) {
        return line.trim().split(FIELD_SEPARATOR);
    }

    /**
     * Extracts the dataset name: the text following {@code "Dataset Name:"}
     * up to the next {@code "("} (or the end of the line if there is none).
     * If several lines match, the last one wins.
     *
     * @param lines the lines of the StRD file
     * @return the name of the dataset
     * @throws IOException if no line holds the dataset name
     */
    private static String findName(final Iterable<String> lines)
        throws IOException {
        String found = null;
        for (String line : lines) {
            final int marker = line.indexOf(NAME_MARKER);
            if (marker >= 0) {
                final int from = marker + NAME_MARKER.length();
                final int paren = line.indexOf('(', from);
                final int to = paren >= 0 ? paren : line.length();
                found = line.substring(from, to).trim();
            }
        }
        if (found == null) {
            throw new IOException("could not find dataset name");
        }
        return found;
    }

    /**
     * Extracts the certified residual sum of squares, taken as the last field
     * of the line containing {@code "Residual Sum of Squares"}. If several
     * lines match, the last one wins.
     *
     * @param lines the lines of the StRD file
     * @return the certified residual sum of squares
     * @throws IOException if no line holds the residual sum of squares
     */
    private static double findResidualSumOfSquares(final Iterable<String> lines)
        throws IOException {
        double found = Double.NaN;
        for (String line : lines) {
            if (line.contains(RSS_MARKER)) {
                final String[] tokens = tokenize(line);
                found = Double.parseDouble(tokens[tokens.length - 1]);
            }
        }
        if (Double.isNaN(found)) {
            throw new IOException(
                "could not find certified value of residual sum of squares");
        }
        return found;
    }

    /**
     * <p>
     * Parses the specified text lines, and extracts the indices of the first
     * and last lines of the data defined by the specified {@code key}. This key
     * must be one of
     * </p>
     * <ul>
     * <li>{@code "Starting Values"},</li>
     * <li>{@code "Certified Values"},</li>
     * <li>{@code "Data"}.</li>
     * </ul>
     * <p>
     * In the NIST data files, the line indices are separated by the keywords
     * {@code "lines"} and {@code "to"}, and terminated by {@code ")"}, e.g.
     * {@code "Data (lines 61 to 74)"}. The line indices are 1-based.
     * </p>
     *
     * @param key the keyword identifying the section
     * @param lines the lines of text to be parsed
     * @return an array of two {@code int}s. First value is the index of the
     * first line, second value is the index of the last line.
     * {@code null} if no line matches.
     */
    private static int[] findLineNumbers(final String key,
                                         final Iterable<String> lines) {
        for (String text : lines) {
            if (!text.contains(key)) {
                continue;
            }
            // Each delimiter is searched after the previous one, so that an
            // earlier stray "to" or ")" cannot produce inverted bounds.
            final int linesPos = text.indexOf("lines");
            if (linesPos < 0) {
                continue;
            }
            final int fromStart = linesPos + "lines".length();
            final int toPos = text.indexOf("to", fromStart);
            if (toPos < 0) {
                continue;
            }
            final int toStart = toPos + "to".length();
            final int closePos = text.indexOf(')', toStart);
            if (closePos < 0) {
                continue;
            }
            final int[] numbers = new int[2];
            numbers[0] = Integer.parseInt(text.substring(fromStart, toPos).trim());
            numbers[1] = Integer.parseInt(text.substring(toStart, closePos).trim());
            return numbers;
        }
        return null;
    }
}