// blob: e248be5fa423c5431d964a180d95ea9ce824fb29 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.optimization;
import org.apache.commons.math.optimization.ConvergenceChecker;
import org.apache.commons.math.optimization.CostException;
import org.apache.commons.math.optimization.CostFunction;
import org.apache.commons.math.optimization.NelderMead;
import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.optimization.PointCostPair;
import org.apache.commons.math.random.JDKRandomGenerator;
import org.apache.commons.math.random.NotPositiveDefiniteMatrixException;
import org.apache.commons.math.random.RandomGenerator;
import org.apache.commons.math.random.RandomVectorGenerator;
import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator;
import org.apache.commons.math.random.UniformRandomGenerator;
import junit.framework.*;
/**
 * JUnit 3 tests for the {@link NelderMead} optimizer.
 * <p>
 * The tests minimize small cost functions and check how exceptions raised by a
 * cost function are reported, how many cost evaluations are performed and how
 * close the returned optimum is to the expected one.
 * </p>
 */
public class NelderMeadTest
  extends TestCase {

  /**
   * Number of cost function evaluations performed since the counter was last
   * reset; incremented by the counting cost functions defined in the tests.
   */
  private int count;

  /**
   * Builds a test case with the given name.
   * @param name name of the test case
   */
  public NelderMeadTest(String name) {
    super(name);
  }

  /**
   * Checks that a {@link CostException} thrown by the cost function comes out
   * of {@code minimize}, once without a cause and once with one.
   * @exception ConvergenceException declared for the optimizer calls; the
   * catch-all clauses below turn any unexpected exception into a test failure
   */
  public void testCostExceptions() throws ConvergenceException {

    // cost function that is only defined on [0, 1]:
    //  - below 0 it fails with a message-only CostException (no cause)
    //  - above 1 it fails with a CostException wrapping a RuntimeException
    CostFunction wrong =
      new CostFunction() {
        public double cost(double[] x) throws CostException {
          if (x[0] < 0) {
            throw new CostException("{0}", new Object[] { "oops"});
          } else if (x[0] > 1) {
            throw new CostException(new RuntimeException("oops"));
          } else {
            return x[0] * (1 - x[0]);
          }
        }
      };

    // start interval reaching into the negative side: expect the cause-less exception
    try {
      new NelderMead(0.9, 1.9, 0.4, 0.6).minimize(wrong, 10, new ValueChecker(1.0e-3),
                                                  new double[] { -0.5 }, new double[] { 0.5 });
      fail("an exception should have been thrown");
    } catch (CostException ce) {
      // expected behavior
      assertNull(ce.getCause());
    } catch (Exception e) {
      // report the exception itself so its type is visible even without a message
      fail("wrong exception caught: " + e);
    }

    // start interval reaching above 1: expect the exception wrapping a cause
    try {
      new NelderMead(0.9, 1.9, 0.4, 0.6).minimize(wrong, 10, new ValueChecker(1.0e-3),
                                                  new double[] { 0.5 }, new double[] { 1.5 });
      fail("an exception should have been thrown");
    } catch (CostException ce) {
      // expected behavior
      assertNotNull(ce.getCause());
    } catch (Exception e) {
      // report the exception itself so its type is visible even without a message
      fail("wrong exception caught: " + e);
    }

  }

  /**
   * Minimizes the Rosenbrock function {@code 100 (y - x^2)^2 + (1 - x)^2}
   * through several {@code minimize} variants.
   * <p>
   * The first run is expected to fail with a {@link ConvergenceException};
   * the following runs are expected to end close to the point (1, 1) where
   * the function is zero.
   * </p>
   * @exception CostException if the cost function cannot be evaluated
   * @exception ConvergenceException if a run that should converge does not
   * @exception NotPositiveDefiniteMatrixException declared by the optimizer
   * variant taking explicit vertices
   */
  public void testRosenbrock()
    throws CostException, ConvergenceException, NotPositiveDefiniteMatrixException {

    // Rosenbrock function, counting its own evaluations
    CostFunction rosenbrock =
      new CostFunction() {
        public double cost(double[] x) {
          ++count;
          double a = x[1] - x[0] * x[0];
          double b = 1.0 - x[0];
          return 100 * a * a + b * b;
        }
      };

    count = 0;
    NelderMead nm = new NelderMead();

    // the trailing arguments are presumably the number of starts and the
    // random seed (confirm against NelderMead); this configuration is
    // expected not to converge
    try {
      nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
                  new double[][] {
                    { -1.2, 1.0 }, { 3.5, -2.3 }, { 0.4, 1.5 }
                  }, 1, 5384353L);
      fail("an exception should have been thrown");
    } catch (ConvergenceException ce) {
      // expected behavior
    } catch (Exception e) {
      // report the exception itself so its type is visible even without a message
      fail("wrong exception caught: " + e);
    }

    // same function with other vertices, 10 as the trailing count and another seed
    count = 0;
    PointCostPair optimum =
      nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
                  new double[][] {
                    { -1.2, 1.0 }, { 0.9, 1.2 }, { 3.5, -2.3 }
                  }, 10, 1642738L);
    assertTrue(count > 700);
    assertTrue(count < 800);
    assertEquals(0.0, optimum.getCost(), 5.0e-5);
    assertEquals(1.0, optimum.getPoint()[0], 0.01);
    assertEquals(1.0, optimum.getPoint()[1], 0.01);

    // the minima array must hold its non-null entries first, ordered by
    // non-decreasing cost, followed only by null entries
    PointCostPair[] minima = nm.getMinima();
    assertEquals(10, minima.length);
    assertNotNull(minima[0]);
    assertNull(minima[minima.length - 1]);
    for (int i = 0; i < minima.length; ++i) {
      if (minima[i] == null) {
        // once a null entry is met, every later entry must be null too
        if ((i + 1) < minima.length) {
          assertNull(minima[i + 1]);
        }
      } else {
        if (i > 0) {
          assertTrue(minima[i - 1].getCost() <= minima[i].getCost());
        }
      }
    }

    // vertices drawn from a seeded random vector generator
    RandomGenerator rg = new JDKRandomGenerator();
    rg.setSeed(64453353L);
    RandomVectorGenerator rvg =
      new UncorrelatedRandomVectorGenerator(new double[] { 0.9, 1.1 },
                                            new double[] { 0.2, 0.2 },
                                            new UniformRandomGenerator(rg));
    optimum =
      nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3), rvg);
    assertEquals(0.0, optimum.getCost(), 2.0e-4);

    // same generator (its state has advanced), with 3 as the trailing count
    optimum =
      nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3), rvg, 3);
    assertEquals(0.0, optimum.getCost(), 3.0e-5);

  }

  /**
   * Minimizes the four-variable Powell function and checks that fewer than
   * 150 cost evaluations lead to a point close to the origin, where the
   * function is zero.
   * @exception CostException if the cost function cannot be evaluated
   * @exception ConvergenceException if the optimizer does not converge
   */
  public void testPowell()
    throws CostException, ConvergenceException {

    // Powell function, counting its own evaluations
    CostFunction powell =
      new CostFunction() {
        public double cost(double[] x) {
          ++count;
          double a = x[0] + 10 * x[1];
          double b = x[2] - x[3];
          double c = x[1] - 2 * x[2];
          double d = x[0] - x[3];
          return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
        }
      };

    count = 0;
    NelderMead nm = new NelderMead();
    PointCostPair optimum =
      nm.minimize(powell, 200, new ValueChecker(1.0e-3),
                  new double[] {  3.0, -1.0, 0.0, 1.0 },
                  new double[] {  4.0,  0.0, 1.0, 2.0 },
                  1, 1642738L);
    assertTrue(count < 150);
    assertEquals(0.0, optimum.getCost(), 6.0e-4);
    assertEquals(0.0, optimum.getPoint()[0], 0.07);
    assertEquals(0.0, optimum.getPoint()[1], 0.07);
    assertEquals(0.0, optimum.getPoint()[2], 0.07);
    assertEquals(0.0, optimum.getPoint()[3], 0.07);

  }

  /**
   * Convergence checker comparing the costs of the first and last entries
   * of the simplex.
   */
  private static class ValueChecker implements ConvergenceChecker {

    /** Largest cost spread (exclusive) for which the simplex is considered converged. */
    private final double threshold;

    /**
     * Builds a checker.
     * @param threshold convergence threshold on the cost difference
     */
    public ValueChecker(double threshold) {
      this.threshold = threshold;
    }

    /**
     * Checks whether the cost of the last simplex entry exceeds the cost of
     * the first one by less than the threshold.
     * @param simplex simplex to check; presumably ordered by increasing
     * cost, as the local names suggest (to be confirmed against the optimizer)
     * @return true if the cost difference is strictly below the threshold
     */
    public boolean converged(PointCostPair[] simplex) {
      PointCostPair smallest = simplex[0];
      PointCostPair largest  = simplex[simplex.length - 1];
      return (largest.getCost() - smallest.getCost()) < threshold;
    }

  }

  /**
   * Builds the suite holding all the tests of this class.
   * @return test suite built from this class
   */
  public static Test suite() {
    return new TestSuite(NelderMeadTest.class);
  }

}