blob: 18fbe1d5ed912df933263b51c3f0468dedf57fca [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.analysis.differentiation;
import org.apache.commons.math3.TestUtils;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.util.FastMath;
import org.junit.Test;
/**
 * Unit tests for {@link GradientFunction}.
 * <p>
 * Each test wraps a Euclidean-norm function in a {@link GradientFunction} and checks, on a
 * regular grid of sample points, that the gradient it produces from the function's
 * {@link DerivativeStructure} evaluation agrees with a hand-written analytical gradient.
 * </p>
 */
public class GradientFunctionTest {

    /** Absolute tolerance used for the element-wise gradient comparison. */
    private static final double TOLERANCE = 1.0e-15;

    /** Spacing between consecutive grid coordinates along each axis. */
    private static final double STEP = 0.5;

    /**
     * Number of grid steps on each side of zero; coordinates run over
     * {@code i * STEP} for {@code i} in {@code [-HALF_STEPS, HALF_STEPS)}, i.e. -10 to 9.5.
     */
    private static final int HALF_STEPS = 20;

    @Test
    public void test2DDistance() {
        final EuclideanDistance norm = new EuclideanDistance();
        final GradientFunction gradient = new GradientFunction(norm);
        for (int i = -HALF_STEPS; i < HALF_STEPS; ++i) {
            for (int j = -HALF_STEPS; j < HALF_STEPS; ++j) {
                checkGradientAt(norm, gradient, new double[] { i * STEP, j * STEP });
            }
        }
    }

    @Test
    public void test3DDistance() {
        final EuclideanDistance norm = new EuclideanDistance();
        final GradientFunction gradient = new GradientFunction(norm);
        for (int i = -HALF_STEPS; i < HALF_STEPS; ++i) {
            for (int j = -HALF_STEPS; j < HALF_STEPS; ++j) {
                for (int k = -HALF_STEPS; k < HALF_STEPS; ++k) {
                    checkGradientAt(norm, gradient,
                                    new double[] { i * STEP, j * STEP, k * STEP });
                }
            }
        }
    }

    /**
     * Asserts that the gradient computed by {@code g} matches the analytical gradient of
     * {@code f} at the given point.
     * <p>
     * The grid includes the origin, where the analytical gradient is 0/0 (NaN).
     * NOTE(review): this presumably relies on {@code TestUtils.assertEquals} treating
     * matching NaN entries as equal — confirm.
     * </p>
     *
     * @param f function supplying the reference (analytical) gradient
     * @param g gradient function built from {@code f}
     * @param point point at which both gradients are evaluated
     */
    private static void checkGradientAt(final EuclideanDistance f, final GradientFunction g,
                                        final double[] point) {
        final double[] expected = f.gradient(point);
        final double[] actual = g.value(point);
        TestUtils.assertEquals(expected, actual, TOLERANCE);
    }

    /**
     * Euclidean norm {@code sqrt(x1^2 + ... + xn^2)} of its argument, with an analytical
     * gradient used as the reference in these tests.
     */
    private static class EuclideanDistance implements MultivariateDifferentiableFunction {

        /**
         * Evaluates the norm on plain doubles.
         *
         * @param point coordinates
         * @return square root of the sum of squared coordinates
         */
        public double value(double[] point) {
            double sumOfSquares = 0;
            for (int i = 0; i < point.length; ++i) {
                sumOfSquares += point[i] * point[i];
            }
            return FastMath.sqrt(sumOfSquares);
        }

        /**
         * Evaluates the norm on derivative structures, so that derivatives propagate.
         * Assumes {@code point} is non-empty, since {@code point[0]} supplies the field.
         *
         * @param point coordinates as derivative structures
         * @return square root of the sum of squared coordinates
         */
        public DerivativeStructure value(DerivativeStructure[] point)
            throws DimensionMismatchException, MathIllegalArgumentException {
            DerivativeStructure sumOfSquares = point[0].getField().getZero();
            for (int i = 0; i < point.length; ++i) {
                sumOfSquares = sumOfSquares.add(point[i].multiply(point[i]));
            }
            return sumOfSquares.sqrt();
        }

        /**
         * Analytical gradient of the norm: component {@code i} is {@code x_i / ||x||}.
         * At the origin every component is 0/0, i.e. NaN.
         *
         * @param point coordinates
         * @return gradient vector, same length as {@code point}
         */
        public double[] gradient(double[] point) {
            final double norm = value(point);
            final double[] components = new double[point.length];
            for (int i = 0; i < components.length; ++i) {
                components[i] = point[i] / norm;
            }
            return components;
        }
    }
}