/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| |
| package org.apache.commons.math4.optim.univariate; |
| |
| import java.util.Arrays; |
| import java.util.Comparator; |
| |
| import org.apache.commons.math4.exception.MathIllegalStateException; |
| import org.apache.commons.math4.exception.NotStrictlyPositiveException; |
| import org.apache.commons.math4.exception.util.LocalizedFormats; |
| import org.apache.commons.math4.optim.MaxEval; |
| import org.apache.commons.math4.optim.OptimizationData; |
| import org.apache.commons.math4.optim.nonlinear.scalar.GoalType; |
| import org.apache.commons.rng.UniformRandomProvider; |
| |
| /** |
| * Special implementation of the {@link UnivariateOptimizer} interface |
| * adding multi-start features to an existing optimizer. |
| * <br> |
| * This class wraps an optimizer in order to use it several times in |
| * turn with different starting points (trying to avoid being trapped |
| * in a local extremum when looking for a global one). |
| * |
| * @since 3.0 |
| */ |
| public class MultiStartUnivariateOptimizer |
| extends UnivariateOptimizer { |
| /** Underlying classical optimizer. */ |
| private final UnivariateOptimizer optimizer; |
| /** Number of evaluations already performed for all starts. */ |
| private int totalEvaluations; |
| /** Number of starts to go. */ |
| private final int starts; |
| /** Random generator for multi-start. */ |
| private final UniformRandomProvider generator; |
| /** Found optima. */ |
| private UnivariatePointValuePair[] optima; |
| /** Optimization data. */ |
| private OptimizationData[] optimData; |
| /** |
| * Location in {@link #optimData} where the updated maximum |
| * number of evaluations will be stored. |
| */ |
| private int maxEvalIndex = -1; |
| /** |
| * Location in {@link #optimData} where the updated start value |
| * will be stored. |
| */ |
| private int searchIntervalIndex = -1; |
| |
| /** |
| * Create a multi-start optimizer from a single-start optimizer. |
| * |
| * @param optimizer Single-start optimizer to wrap. |
| * @param starts Number of starts to perform. If {@code starts == 1}, |
| * the {@code optimize} methods will return the same solution as |
| * {@code optimizer} would. |
| * @param generator Random generator to use for restarts. |
| * @throws NotStrictlyPositiveException if {@code starts < 1}. |
| */ |
| public MultiStartUnivariateOptimizer(final UnivariateOptimizer optimizer, |
| final int starts, |
| final UniformRandomProvider generator) { |
| super(optimizer.getConvergenceChecker()); |
| |
| if (starts < 1) { |
| throw new NotStrictlyPositiveException(starts); |
| } |
| |
| this.optimizer = optimizer; |
| this.starts = starts; |
| this.generator = generator; |
| } |
| |
| /** {@inheritDoc} */ |
| @Override |
| public int getEvaluations() { |
| return totalEvaluations; |
| } |
| |
| /** |
| * Gets all the optima found during the last call to {@code optimize}. |
| * The optimizer stores all the optima found during a set of |
| * restarts. The {@code optimize} method returns the best point only. |
| * This method returns all the points found at the end of each starts, |
| * including the best one already returned by the {@code optimize} method. |
| * <br> |
| * The returned array as one element for each start as specified |
| * in the constructor. It is ordered with the results from the |
| * runs that did converge first, sorted from best to worst |
| * objective value (i.e in ascending order if minimizing and in |
| * descending order if maximizing), followed by {@code null} elements |
| * corresponding to the runs that did not converge. This means all |
| * elements will be {@code null} if the {@code optimize} method did throw |
| * an exception. |
| * This also means that if the first element is not {@code null}, it is |
| * the best point found across all starts. |
| * |
| * @return an array containing the optima. |
| * @throws MathIllegalStateException if {@link #optimize(OptimizationData[]) |
| * optimize} has not been called. |
| */ |
| public UnivariatePointValuePair[] getOptima() { |
| if (optima == null) { |
| throw new MathIllegalStateException(LocalizedFormats.NO_OPTIMUM_COMPUTED_YET); |
| } |
| return optima.clone(); |
| } |
| |
| /** |
| * {@inheritDoc} |
| * |
| * @throws MathIllegalStateException if {@code optData} does not contain an |
| * instance of {@link MaxEval} or {@link SearchInterval}. |
| */ |
| @Override |
| public UnivariatePointValuePair optimize(OptimizationData... optData) { |
| // Store arguments in order to pass them to the internal optimizer. |
| optimData = optData; |
| // Set up base class and perform computations. |
| return super.optimize(optData); |
| } |
| |
| /** {@inheritDoc} */ |
| @Override |
| protected UnivariatePointValuePair doOptimize() { |
| // Remove all instances of "MaxEval" and "SearchInterval" from the |
| // array that will be passed to the internal optimizer. |
| // The former is to enforce smaller numbers of allowed evaluations |
| // (according to how many have been used up already), and the latter |
| // to impose a different start value for each start. |
| for (int i = 0; i < optimData.length; i++) { |
| if (optimData[i] instanceof MaxEval) { |
| optimData[i] = null; |
| maxEvalIndex = i; |
| continue; |
| } |
| if (optimData[i] instanceof SearchInterval) { |
| optimData[i] = null; |
| searchIntervalIndex = i; |
| continue; |
| } |
| } |
| if (maxEvalIndex == -1) { |
| throw new MathIllegalStateException(); |
| } |
| if (searchIntervalIndex == -1) { |
| throw new MathIllegalStateException(); |
| } |
| |
| RuntimeException lastException = null; |
| optima = new UnivariatePointValuePair[starts]; |
| totalEvaluations = 0; |
| |
| final int maxEval = getMaxEvaluations(); |
| final double min = getMin(); |
| final double max = getMax(); |
| final double startValue = getStartValue(); |
| |
| // Multi-start loop. |
| for (int i = 0; i < starts; i++) { |
| // CHECKSTYLE: stop IllegalCatch |
| try { |
| // Decrease number of allowed evaluations. |
| optimData[maxEvalIndex] = new MaxEval(maxEval - totalEvaluations); |
| // New start value. |
| final double s = (i == 0) ? |
| startValue : |
| min + generator.nextDouble() * (max - min); |
| optimData[searchIntervalIndex] = new SearchInterval(min, max, s); |
| // Optimize. |
| optima[i] = optimizer.optimize(optimData); |
| } catch (RuntimeException mue) { |
| lastException = mue; |
| optima[i] = null; |
| } |
| // CHECKSTYLE: resume IllegalCatch |
| |
| totalEvaluations += optimizer.getEvaluations(); |
| } |
| |
| sortPairs(getGoalType()); |
| |
| if (optima[0] == null) { |
| throw lastException; // Cannot be null if starts >= 1. |
| } |
| |
| // Return the point with the best objective function value. |
| return optima[0]; |
| } |
| |
| /** |
| * Sort the optima from best to worst, followed by {@code null} elements. |
| * |
| * @param goal Goal type. |
| */ |
| private void sortPairs(final GoalType goal) { |
| Arrays.sort(optima, new Comparator<UnivariatePointValuePair>() { |
| /** {@inheritDoc} */ |
| @Override |
| public int compare(final UnivariatePointValuePair o1, |
| final UnivariatePointValuePair o2) { |
| if (o1 == null) { |
| return (o2 == null) ? 0 : 1; |
| } else if (o2 == null) { |
| return -1; |
| } |
| final double v1 = o1.getValue(); |
| final double v2 = o2.getValue(); |
| return (goal == GoalType.MINIMIZE) ? |
| Double.compare(v1, v2) : Double.compare(v2, v1); |
| } |
| }); |
| } |
| } |