blob: 012c913ef0f8851c25b4a6a97e933356196f43c6 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.legacy.fitting;
import java.util.Collections;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.ArrayList;
import org.apache.commons.math4.legacy.exception.ZeroException;
import org.apache.commons.math4.legacy.exception.OutOfRangeException;
import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
/**
 * Fits points to a user-defined {@link ParametricUnivariateFunction function}.
 * <p>
 * The optimization starts either from a user-supplied initial guess or from
 * parameters computed from the observations by a {@link ParameterGuesser}.
 * The {@code with...} methods do not modify this instance: they return a new,
 * reconfigured fitter.
 * </p>
 *
 * @since 3.4
 */
public class SimpleCurveFitter extends AbstractCurveFitter {
    /** Function to fit. */
    private final ParametricUnivariateFunction function;
    /** Initial guess for the parameters ({@code null} when {@link #guesser} must be used). */
    private final double[] initialGuess;
    /** Parameter guesser (only consulted when {@link #initialGuess} is {@code null}). */
    private final ParameterGuesser guesser;
    /** Maximum number of iterations of the optimization algorithm. */
    private final int maxIter;

    /**
     * Constructor used by the factory methods.
     * <p>
     * The {@code initialGuess} array is stored as-is (it is not copied); the
     * factory methods pass a private copy so that instances are not affected
     * by later changes to the caller's array.
     * </p>
     *
     * @param function Function to fit.
     * @param initialGuess Initial guess, or {@code null} if {@code guesser}
     * must be used. When not {@code null}, its length must be consistent with
     * the number of parameters of the {@code function} to fit.
     * @param guesser Method for providing an initial guess (used only if
     * {@code initialGuess} is {@code null}).
     * @param maxIter Maximum number of iterations of the optimization algorithm.
     */
    protected SimpleCurveFitter(ParametricUnivariateFunction function,
                                double[] initialGuess,
                                ParameterGuesser guesser,
                                int maxIter) {
        this.function = function;
        this.initialGuess = initialGuess;
        this.guesser = guesser;
        this.maxIter = maxIter;
    }

    /**
     * Creates a curve fitter.
     * The maximum number of iterations of the optimization algorithm is set
     * to {@link Integer#MAX_VALUE}.
     *
     * @param f Function to fit.
     * @param start Initial guess for the parameters. Cannot be {@code null}.
     * Its length must be consistent with the number of parameters of the
     * function to fit. The array is copied, so later modifications of it do
     * not affect the returned fitter.
     * @return a curve fitter.
     *
     * @see #withStartPoint(double[])
     * @see #withMaxIterations(int)
     */
    public static SimpleCurveFitter create(ParametricUnivariateFunction f,
                                           double[] start) {
        // Defensive copy, consistent with withStartPoint(double[]): the caller
        // keeps ownership of its array and cannot alter this fitter through it.
        return new SimpleCurveFitter(f, start.clone(), null, Integer.MAX_VALUE);
    }

    /**
     * Creates a curve fitter.
     * The maximum number of iterations of the optimization algorithm is set
     * to {@link Integer#MAX_VALUE}.
     *
     * @param f Function to fit.
     * @param guesser Method for providing an initial guess. It is invoked on
     * the observations when fitting, so it should not be {@code null}.
     * @return a curve fitter.
     *
     * @see #withStartPoint(double[])
     * @see #withMaxIterations(int)
     */
    public static SimpleCurveFitter create(ParametricUnivariateFunction f,
                                           ParameterGuesser guesser) {
        return new SimpleCurveFitter(f, null, guesser, Integer.MAX_VALUE);
    }

    /**
     * Configure the start point (initial guess).
     * The returned fitter uses a copy of {@code newStart} and no longer uses
     * any {@link ParameterGuesser}.
     *
     * @param newStart new start point (initial guess)
     * @return a new instance.
     */
    public SimpleCurveFitter withStartPoint(double[] newStart) {
        return new SimpleCurveFitter(function,
                                     newStart.clone(),
                                     null,
                                     maxIter);
    }

    /**
     * Configure the maximum number of iterations.
     * @param newMaxIter maximum number of iterations
     * @return a new instance.
     */
    public SimpleCurveFitter withMaxIterations(int newMaxIter) {
        return new SimpleCurveFitter(function,
                                     initialGuess,
                                     guesser,
                                     newMaxIter);
    }

    /**
     * {@inheritDoc}
     * <p>
     * The start point is the configured initial guess if there is one;
     * otherwise it is computed by the configured {@link ParameterGuesser}
     * (a {@code NullPointerException} results if neither is available).
     * </p>
     */
    @Override
    protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
        // Prepare least-squares problem: observed Y values are the targets
        // and observation weights form a diagonal weight matrix.
        final int len = observations.size();
        final double[] target = new double[len];
        final double[] weights = new double[len];

        int count = 0;
        for (WeightedObservedPoint obs : observations) {
            target[count] = obs.getY();
            weights[count] = obs.getWeight();
            ++count;
        }

        final AbstractCurveFitter.TheoreticalValuesFunction model
            = new AbstractCurveFitter.TheoreticalValuesFunction(function,
                                                                observations);

        final double[] startPoint = initialGuess != null ?
            initialGuess :
            // Compute estimation.
            guesser.guess(observations);

        // Create an optimizer for fitting the curve to the observed points.
        return new LeastSquaresBuilder().
                maxEvaluations(Integer.MAX_VALUE).
                maxIterations(maxIter).
                start(startPoint).
                target(target).
                weight(new DiagonalMatrix(weights)).
                model(model.getModelFunction(), model.getModelFunctionJacobian()).
                build();
    }

    /**
     * Guesses the parameters.
     * Subclasses implement {@link #guess(Collection)} and may use the
     * protected helpers to sort observations and interpolate between them.
     */
    public abstract static class ParameterGuesser {
        /**
         * Comparator ordering points by X, then Y, then weight.
         * {@code null} points sort before non-{@code null} ones.
         */
        private static final Comparator<WeightedObservedPoint> CMP = new Comparator<WeightedObservedPoint>() {
            /** {@inheritDoc} */
            @Override
            public int compare(WeightedObservedPoint p1,
                               WeightedObservedPoint p2) {
                if (p1 == null && p2 == null) {
                    return 0;
                }
                if (p1 == null) {
                    return -1;
                }
                if (p2 == null) {
                    return 1;
                }
                int comp = Double.compare(p1.getX(), p2.getX());
                if (comp != 0) {
                    return comp;
                }
                comp = Double.compare(p1.getY(), p2.getY());
                if (comp != 0) {
                    return comp;
                }
                return Double.compare(p1.getWeight(), p2.getWeight());
            }
        };

        /**
         * Computes an estimation of the parameters.
         *
         * @param obs Observations.
         * @return the guessed parameters.
         */
        public abstract double[] guess(Collection<WeightedObservedPoint> obs);

        /**
         * Sort the observations.
         * The input collection is not modified; a new sorted list is returned.
         *
         * @param unsorted Input observations.
         * @return a new list containing the input observations, sorted by X,
         * then Y, then weight.
         */
        protected List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
            final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
            Collections.sort(observations, CMP);
            return observations;
        }

        /**
         * Finds index of point in specified points with the largest Y.
         * When several points share the largest Y, the lowest index is returned.
         *
         * @param points Points to search. Should not be empty: for an empty
         * array, 0 is returned, which is not a valid index.
         * @return the index in specified points array.
         */
        protected int findMaxY(WeightedObservedPoint[] points) {
            int maxYIdx = 0;
            for (int i = 1; i < points.length; i++) {
                if (points[i].getY() > points[maxYIdx].getY()) {
                    maxYIdx = i;
                }
            }
            return maxYIdx;
        }

        /**
         * Interpolates using the specified points to determine X at the
         * specified Y.
         *
         * @param points Points to use for interpolation.
         * @param startIdx Index within points from which to start the search for
         * interpolation bounds points.
         * @param idxStep Index step for searching interpolation bounds points.
         * @param y Y value for which X should be determined.
         * @return the value of X for the specified Y.
         * @throws ZeroException if {@code idxStep} is 0.
         * @throws OutOfRangeException if specified {@code y} is not within the
         * range of the specified {@code points}.
         */
        protected double interpolateXAtY(WeightedObservedPoint[] points,
                                         int startIdx,
                                         int idxStep,
                                         double y) {
            if (idxStep == 0) {
                throw new ZeroException();
            }
            final WeightedObservedPoint[] twoPoints
                = getInterpolationPointsForY(points, startIdx, idxStep, y);
            final WeightedObservedPoint p1 = twoPoints[0];
            final WeightedObservedPoint p2 = twoPoints[1];
            // Exact hits are returned directly; this also avoids a 0/0
            // division when both bounding points have the same Y.
            if (p1.getY() == y) {
                return p1.getX();
            }
            if (p2.getY() == y) {
                return p2.getX();
            }
            // Linear interpolation between the two bounding points.
            return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
                                (p2.getY() - p1.getY()));
        }

        /**
         * Gets the two bounding interpolation points from the specified points
         * suitable for determining X at the specified Y.
         * The search walks from {@code startIdx} in steps of {@code idxStep}
         * and stops at the first consecutive pair whose Y values enclose
         * {@code y}.
         *
         * @param points Points to use for interpolation.
         * @param startIdx Index within points from which to start search for
         * interpolation bounds points.
         * @param idxStep Index step for search for interpolation bounds points.
         * @param y Y value for which X should be determined.
         * @return the array containing two points suitable for determining X at
         * the specified Y, ordered by increasing index in {@code points}.
         * @throws ZeroException if {@code idxStep} is 0.
         * @throws OutOfRangeException if specified {@code y} is not within the
         * range of the specified {@code points}.
         */
        private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
                                                                   int startIdx,
                                                                   int idxStep,
                                                                   double y) {
            if (idxStep == 0) {
                throw new ZeroException();
            }
            for (int i = startIdx;
                 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
                 i += idxStep) {
                final WeightedObservedPoint p1 = points[i];
                final WeightedObservedPoint p2 = points[i + idxStep];
                if (isBetween(y, p1.getY(), p2.getY())) {
                    if (idxStep < 0) {
                        return new WeightedObservedPoint[] { p2, p1 };
                    } else {
                        return new WeightedObservedPoint[] { p1, p2 };
                    }
                }
            }

            // Boundaries are replaced by dummy values because the raised
            // exception is caught and the message never displayed.
            // TODO: Exceptions should not be used for flow control.
            throw new OutOfRangeException(y,
                                          Double.NEGATIVE_INFINITY,
                                          Double.POSITIVE_INFINITY);
        }

        /**
         * Determines whether a value is between two other values.
         *
         * @param value Value to test whether it is between {@code boundary1}
         * and {@code boundary2}.
         * @param boundary1 One end of the range.
         * @param boundary2 Other end of the range.
         * @return {@code true} if {@code value} is between {@code boundary1} and
         * {@code boundary2} (inclusive, in either order), {@code false} otherwise.
         */
        private boolean isBetween(double value,
                                  double boundary1,
                                  double boundary2) {
            return (value >= boundary1 && value <= boundary2) ||
                (value >= boundary2 && value <= boundary1);
        }
    }
}