| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.commons.math3.optimization; |
| |
| import java.util.Arrays; |
| import java.util.Comparator; |
| |
| import org.apache.commons.math3.analysis.MultivariateVectorFunction; |
| import org.apache.commons.math3.exception.ConvergenceException; |
| import org.apache.commons.math3.exception.MathIllegalStateException; |
| import org.apache.commons.math3.exception.NotStrictlyPositiveException; |
| import org.apache.commons.math3.exception.NullArgumentException; |
| import org.apache.commons.math3.exception.util.LocalizedFormats; |
| import org.apache.commons.math3.random.RandomVectorGenerator; |
| |
| /** |
| * Base class for all implementations of a multi-start optimizer. |
| * |
| * This interface is mainly intended to enforce the internal coherence of |
| * Commons-Math. Users of the API are advised to base their code on |
| * {@link DifferentiableMultivariateVectorMultiStartOptimizer}. |
| * |
| * @param <FUNC> Type of the objective function to be optimized. |
| * |
| * @deprecated As of 3.1 (to be removed in 4.0). |
| * @since 3.0 |
| */ |
| @Deprecated |
| public class BaseMultivariateVectorMultiStartOptimizer<FUNC extends MultivariateVectorFunction> |
| implements BaseMultivariateVectorOptimizer<FUNC> { |
| /** Underlying classical optimizer. */ |
| private final BaseMultivariateVectorOptimizer<FUNC> optimizer; |
| /** Maximal number of evaluations allowed. */ |
| private int maxEvaluations; |
| /** Number of evaluations already performed for all starts. */ |
| private int totalEvaluations; |
| /** Number of starts to go. */ |
| private int starts; |
| /** Random generator for multi-start. */ |
| private RandomVectorGenerator generator; |
| /** Found optima. */ |
| private PointVectorValuePair[] optima; |
| |
| /** |
| * Create a multi-start optimizer from a single-start optimizer. |
| * |
| * @param optimizer Single-start optimizer to wrap. |
| * @param starts Number of starts to perform. If {@code starts == 1}, |
| * the {@link #optimize(int,MultivariateVectorFunction,double[],double[],double[]) |
| * optimize} will return the same solution as {@code optimizer} would. |
| * @param generator Random vector generator to use for restarts. |
| * @throws NullArgumentException if {@code optimizer} or {@code generator} |
| * is {@code null}. |
| * @throws NotStrictlyPositiveException if {@code starts < 1}. |
| */ |
| protected BaseMultivariateVectorMultiStartOptimizer(final BaseMultivariateVectorOptimizer<FUNC> optimizer, |
| final int starts, |
| final RandomVectorGenerator generator) { |
| if (optimizer == null || |
| generator == null) { |
| throw new NullArgumentException(); |
| } |
| if (starts < 1) { |
| throw new NotStrictlyPositiveException(starts); |
| } |
| |
| this.optimizer = optimizer; |
| this.starts = starts; |
| this.generator = generator; |
| } |
| |
| /** |
| * Get all the optima found during the last call to {@link |
| * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize}. |
| * The optimizer stores all the optima found during a set of |
| * restarts. The {@link #optimize(int,MultivariateVectorFunction,double[],double[],double[]) |
| * optimize} method returns the best point only. This method |
| * returns all the points found at the end of each starts, including |
| * the best one already returned by the {@link |
| * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize} method. |
| * <br/> |
| * The returned array as one element for each start as specified |
| * in the constructor. It is ordered with the results from the |
| * runs that did converge first, sorted from best to worst |
| * objective value (i.e. in ascending order if minimizing and in |
| * descending order if maximizing), followed by and null elements |
| * corresponding to the runs that did not converge. This means all |
| * elements will be null if the {@link |
| * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize} method did |
| * throw a {@link ConvergenceException}). This also means that if |
| * the first element is not {@code null}, it is the best point found |
| * across all starts. |
| * |
| * @return array containing the optima |
| * @throws MathIllegalStateException if {@link |
| * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize} has not been |
| * called. |
| */ |
| public PointVectorValuePair[] getOptima() { |
| if (optima == null) { |
| throw new MathIllegalStateException(LocalizedFormats.NO_OPTIMUM_COMPUTED_YET); |
| } |
| return optima.clone(); |
| } |
| |
| /** {@inheritDoc} */ |
| public int getMaxEvaluations() { |
| return maxEvaluations; |
| } |
| |
| /** {@inheritDoc} */ |
| public int getEvaluations() { |
| return totalEvaluations; |
| } |
| |
| /** {@inheritDoc} */ |
| public ConvergenceChecker<PointVectorValuePair> getConvergenceChecker() { |
| return optimizer.getConvergenceChecker(); |
| } |
| |
| /** |
| * {@inheritDoc} |
| */ |
| public PointVectorValuePair optimize(int maxEval, final FUNC f, |
| double[] target, double[] weights, |
| double[] startPoint) { |
| maxEvaluations = maxEval; |
| RuntimeException lastException = null; |
| optima = new PointVectorValuePair[starts]; |
| totalEvaluations = 0; |
| |
| // Multi-start loop. |
| for (int i = 0; i < starts; ++i) { |
| |
| // CHECKSTYLE: stop IllegalCatch |
| try { |
| optima[i] = optimizer.optimize(maxEval - totalEvaluations, f, target, weights, |
| i == 0 ? startPoint : generator.nextVector()); |
| } catch (ConvergenceException oe) { |
| optima[i] = null; |
| } catch (RuntimeException mue) { |
| lastException = mue; |
| optima[i] = null; |
| } |
| // CHECKSTYLE: resume IllegalCatch |
| |
| totalEvaluations += optimizer.getEvaluations(); |
| } |
| |
| sortPairs(target, weights); |
| |
| if (optima[0] == null) { |
| throw lastException; // cannot be null if starts >=1 |
| } |
| |
| // Return the found point given the best objective function value. |
| return optima[0]; |
| } |
| |
| /** |
| * Sort the optima from best to worst, followed by {@code null} elements. |
| * |
| * @param target Target value for the objective functions at optimum. |
| * @param weights Weights for the least-squares cost computation. |
| */ |
| private void sortPairs(final double[] target, |
| final double[] weights) { |
| Arrays.sort(optima, new Comparator<PointVectorValuePair>() { |
| public int compare(final PointVectorValuePair o1, |
| final PointVectorValuePair o2) { |
| if (o1 == null) { |
| return (o2 == null) ? 0 : 1; |
| } else if (o2 == null) { |
| return -1; |
| } |
| return Double.compare(weightedResidual(o1), weightedResidual(o2)); |
| } |
| private double weightedResidual(final PointVectorValuePair pv) { |
| final double[] value = pv.getValueRef(); |
| double sum = 0; |
| for (int i = 0; i < value.length; ++i) { |
| final double ri = value[i] - target[i]; |
| sum += weights[i] * ri * ri; |
| } |
| return sum; |
| } |
| }); |
| } |
| } |