| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.commons.math4.legacy.fitting; |
| |
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;

import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
import org.apache.commons.math4.legacy.exception.OutOfRangeException;
import org.apache.commons.math4.legacy.exception.ZeroException;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
| |
| /** |
| * Fits points to a user-defined {@link ParametricUnivariateFunction function}. |
| * |
| * @since 3.4 |
| */ |
| public class SimpleCurveFitter extends AbstractCurveFitter { |
| /** Function to fit. */ |
| private final ParametricUnivariateFunction function; |
| /** Initial guess for the parameters. */ |
| private final double[] initialGuess; |
| /** Parameter guesser. */ |
| private final ParameterGuesser guesser; |
| /** Maximum number of iterations of the optimization algorithm. */ |
| private final int maxIter; |
| |
| /** |
| * Constructor used by the factory methods. |
| * |
| * @param function Function to fit. |
| * @param initialGuess Initial guess. Cannot be {@code null}. Its length must |
| * be consistent with the number of parameters of the {@code function} to fit. |
| * @param guesser Method for providing an initial guess (if {@code initialGuess} |
| * is {@code null}). |
| * @param maxIter Maximum number of iterations of the optimization algorithm. |
| */ |
| protected SimpleCurveFitter(ParametricUnivariateFunction function, |
| double[] initialGuess, |
| ParameterGuesser guesser, |
| int maxIter) { |
| this.function = function; |
| this.initialGuess = initialGuess; |
| this.guesser = guesser; |
| this.maxIter = maxIter; |
| } |
| |
| /** |
| * Creates a curve fitter. |
| * The maximum number of iterations of the optimization algorithm is set |
| * to {@link Integer#MAX_VALUE}. |
| * |
| * @param f Function to fit. |
| * @param start Initial guess for the parameters. Cannot be {@code null}. |
| * Its length must be consistent with the number of parameters of the |
| * function to fit. |
| * @return a curve fitter. |
| * |
| * @see #withStartPoint(double[]) |
| * @see #withMaxIterations(int) |
| */ |
| public static SimpleCurveFitter create(ParametricUnivariateFunction f, |
| double[] start) { |
| return new SimpleCurveFitter(f, start, null, Integer.MAX_VALUE); |
| } |
| |
| /** |
| * Creates a curve fitter. |
| * The maximum number of iterations of the optimization algorithm is set |
| * to {@link Integer#MAX_VALUE}. |
| * |
| * @param f Function to fit. |
| * @param guesser Method for providing an initial guess. |
| * @return a curve fitter. |
| * |
| * @see #withStartPoint(double[]) |
| * @see #withMaxIterations(int) |
| */ |
| public static SimpleCurveFitter create(ParametricUnivariateFunction f, |
| ParameterGuesser guesser) { |
| return new SimpleCurveFitter(f, null, guesser, Integer.MAX_VALUE); |
| } |
| |
| /** |
| * Configure the start point (initial guess). |
| * @param newStart new start point (initial guess) |
| * @return a new instance. |
| */ |
| public SimpleCurveFitter withStartPoint(double[] newStart) { |
| return new SimpleCurveFitter(function, |
| newStart.clone(), |
| null, |
| maxIter); |
| } |
| |
| /** |
| * Configure the maximum number of iterations. |
| * @param newMaxIter maximum number of iterations |
| * @return a new instance. |
| */ |
| public SimpleCurveFitter withMaxIterations(int newMaxIter) { |
| return new SimpleCurveFitter(function, |
| initialGuess, |
| guesser, |
| newMaxIter); |
| } |
| |
| /** {@inheritDoc} */ |
| @Override |
| protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) { |
| // Prepare least-squares problem. |
| final int len = observations.size(); |
| final double[] target = new double[len]; |
| final double[] weights = new double[len]; |
| |
| int count = 0; |
| for (WeightedObservedPoint obs : observations) { |
| target[count] = obs.getY(); |
| weights[count] = obs.getWeight(); |
| ++count; |
| } |
| |
| final AbstractCurveFitter.TheoreticalValuesFunction model |
| = new AbstractCurveFitter.TheoreticalValuesFunction(function, |
| observations); |
| |
| final double[] startPoint = initialGuess != null ? |
| initialGuess : |
| // Compute estimation. |
| guesser.guess(observations); |
| |
| // Create an optimizer for fitting the curve to the observed points. |
| return new LeastSquaresBuilder(). |
| maxEvaluations(Integer.MAX_VALUE). |
| maxIterations(maxIter). |
| start(startPoint). |
| target(target). |
| weight(new DiagonalMatrix(weights)). |
| model(model.getModelFunction(), model.getModelFunctionJacobian()). |
| build(); |
| } |
| |
| /** |
| * Guesses the parameters. |
| */ |
| public abstract static class ParameterGuesser { |
| /** Comparator. */ |
| private static final Comparator<WeightedObservedPoint> CMP = new Comparator<WeightedObservedPoint>() { |
| /** {@inheritDoc} */ |
| @Override |
| public int compare(WeightedObservedPoint p1, |
| WeightedObservedPoint p2) { |
| if (p1 == null && p2 == null) { |
| return 0; |
| } |
| if (p1 == null) { |
| return -1; |
| } |
| if (p2 == null) { |
| return 1; |
| } |
| int comp = Double.compare(p1.getX(), p2.getX()); |
| if (comp != 0) { |
| return comp; |
| } |
| comp = Double.compare(p1.getY(), p2.getY()); |
| if (comp != 0) { |
| return comp; |
| } |
| return Double.compare(p1.getWeight(), p2.getWeight()); |
| } |
| }; |
| |
| /** |
| * Computes an estimation of the parameters. |
| * |
| * @param obs Observations. |
| * @return the guessed parameters. |
| */ |
| public abstract double[] guess(Collection<WeightedObservedPoint> obs); |
| |
| /** |
| * Sort the observations. |
| * |
| * @param unsorted Input observations. |
| * @return the input observations, sorted. |
| */ |
| protected List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) { |
| final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted); |
| Collections.sort(observations, CMP); |
| return observations; |
| } |
| |
| /** |
| * Finds index of point in specified points with the largest Y. |
| * |
| * @param points Points to search. |
| * @return the index in specified points array. |
| */ |
| protected int findMaxY(WeightedObservedPoint[] points) { |
| int maxYIdx = 0; |
| for (int i = 1; i < points.length; i++) { |
| if (points[i].getY() > points[maxYIdx].getY()) { |
| maxYIdx = i; |
| } |
| } |
| return maxYIdx; |
| } |
| |
| /** |
| * Interpolates using the specified points to determine X at the |
| * specified Y. |
| * |
| * @param points Points to use for interpolation. |
| * @param startIdx Index within points from which to start the search for |
| * interpolation bounds points. |
| * @param idxStep Index step for searching interpolation bounds points. |
| * @param y Y value for which X should be determined. |
| * @return the value of X for the specified Y. |
| * @throws ZeroException if {@code idxStep} is 0. |
| * @throws OutOfRangeException if specified {@code y} is not within the |
| * range of the specified {@code points}. |
| */ |
| protected double interpolateXAtY(WeightedObservedPoint[] points, |
| int startIdx, |
| int idxStep, |
| double y) { |
| if (idxStep == 0) { |
| throw new ZeroException(); |
| } |
| final WeightedObservedPoint[] twoPoints |
| = getInterpolationPointsForY(points, startIdx, idxStep, y); |
| final WeightedObservedPoint p1 = twoPoints[0]; |
| final WeightedObservedPoint p2 = twoPoints[1]; |
| if (p1.getY() == y) { |
| return p1.getX(); |
| } |
| if (p2.getY() == y) { |
| return p2.getX(); |
| } |
| return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) / |
| (p2.getY() - p1.getY())); |
| } |
| |
| /** |
| * Gets the two bounding interpolation points from the specified points |
| * suitable for determining X at the specified Y. |
| * |
| * @param points Points to use for interpolation. |
| * @param startIdx Index within points from which to start search for |
| * interpolation bounds points. |
| * @param idxStep Index step for search for interpolation bounds points. |
| * @param y Y value for which X should be determined. |
| * @return the array containing two points suitable for determining X at |
| * the specified Y. |
| * @throws ZeroException if {@code idxStep} is 0. |
| * @throws OutOfRangeException if specified {@code y} is not within the |
| * range of the specified {@code points}. |
| */ |
| private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points, |
| int startIdx, |
| int idxStep, |
| double y) { |
| if (idxStep == 0) { |
| throw new ZeroException(); |
| } |
| for (int i = startIdx; |
| idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length; |
| i += idxStep) { |
| final WeightedObservedPoint p1 = points[i]; |
| final WeightedObservedPoint p2 = points[i + idxStep]; |
| if (isBetween(y, p1.getY(), p2.getY())) { |
| if (idxStep < 0) { |
| return new WeightedObservedPoint[] { p2, p1 }; |
| } else { |
| return new WeightedObservedPoint[] { p1, p2 }; |
| } |
| } |
| } |
| |
| // Boundaries are replaced by dummy values because the raised |
| // exception is caught and the message never displayed. |
| // TODO: Exceptions should not be used for flow control. |
| throw new OutOfRangeException(y, |
| Double.NEGATIVE_INFINITY, |
| Double.POSITIVE_INFINITY); |
| } |
| |
| /** |
| * Determines whether a value is between two other values. |
| * |
| * @param value Value to test whether it is between {@code boundary1} |
| * and {@code boundary2}. |
| * @param boundary1 One end of the range. |
| * @param boundary2 Other end of the range. |
| * @return {@code true} if {@code value} is between {@code boundary1} and |
| * {@code boundary2} (inclusive), {@code false} otherwise. |
| */ |
| private boolean isBetween(double value, |
| double boundary1, |
| double boundary2) { |
| return (value >= boundary1 && value <= boundary2) || |
| (value >= boundary2 && value <= boundary1); |
| } |
| } |
| } |