| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.commons.math4.analysis.function; |
| |
| import org.apache.commons.math4.analysis.FunctionUtils; |
| import org.apache.commons.math4.analysis.UnivariateFunction; |
| import org.apache.commons.math4.analysis.differentiation.DerivativeStructure; |
| import org.apache.commons.math4.analysis.differentiation.UnivariateDifferentiableFunction; |
| import org.apache.commons.math4.exception.DimensionMismatchException; |
| import org.apache.commons.math4.exception.NullArgumentException; |
| import org.apache.commons.math4.exception.OutOfRangeException; |
| import org.apache.commons.rng.simple.RandomSource; |
| import org.apache.commons.rng.UniformRandomProvider; |
| import org.apache.commons.math4.util.FastMath; |
| import org.junit.Assert; |
| import org.junit.Test; |
| |
| /** |
| * Test for class {@link Logit}. |
| */ |
| public class LogitTest { |
| private final double EPS = Math.ulp(1d); |
| |
| @Test(expected=OutOfRangeException.class) |
| public void testPreconditions1() { |
| final double lo = -1; |
| final double hi = 2; |
| final UnivariateFunction f = new Logit(lo, hi); |
| |
| f.value(lo - 1); |
| } |
| |
| @Test(expected=OutOfRangeException.class) |
| public void testPreconditions2() { |
| final double lo = -1; |
| final double hi = 2; |
| final UnivariateFunction f = new Logit(lo, hi); |
| |
| f.value(hi + 1); |
| } |
| |
| @Test |
| public void testSomeValues() { |
| final double lo = 1; |
| final double hi = 2; |
| final UnivariateFunction f = new Logit(lo, hi); |
| |
| Assert.assertEquals(Double.NEGATIVE_INFINITY, f.value(1), EPS); |
| Assert.assertEquals(Double.POSITIVE_INFINITY, f.value(2), EPS); |
| Assert.assertEquals(0, f.value(1.5), EPS); |
| } |
| |
| @Test |
| public void testDerivative() { |
| final double lo = 1; |
| final double hi = 2; |
| final Logit f = new Logit(lo, hi); |
| final DerivativeStructure f15 = f.value(new DerivativeStructure(1, 1, 0, 1.5)); |
| |
| Assert.assertEquals(4, f15.getPartialDerivative(1), EPS); |
| } |
| |
| @Test |
| public void testDerivativeLargeArguments() { |
| final Logit f = new Logit(1, 2); |
| |
| for (double arg : new double[] { |
| Double.NEGATIVE_INFINITY, -Double.MAX_VALUE, -1e155, 1e155, Double.MAX_VALUE, Double.POSITIVE_INFINITY |
| }) { |
| try { |
| f.value(new DerivativeStructure(1, 1, 0, arg)); |
| Assert.fail("an exception should have been thrown"); |
| } catch (OutOfRangeException ore) { |
| // expected |
| } catch (Exception e) { |
| Assert.fail("wrong exception caught: " + e.getMessage()); |
| } |
| } |
| } |
| |
| @Test |
| public void testDerivativesHighOrder() { |
| DerivativeStructure l = new Logit(1, 3).value(new DerivativeStructure(1, 5, 0, 1.2)); |
| Assert.assertEquals(-2.1972245773362193828, l.getPartialDerivative(0), 1.0e-16); |
| Assert.assertEquals(5.5555555555555555555, l.getPartialDerivative(1), 9.0e-16); |
| Assert.assertEquals(-24.691358024691358025, l.getPartialDerivative(2), 2.0e-14); |
| Assert.assertEquals(250.34293552812071331, l.getPartialDerivative(3), 2.0e-13); |
| Assert.assertEquals(-3749.4284407864654778, l.getPartialDerivative(4), 4.0e-12); |
| Assert.assertEquals(75001.270131585632282, l.getPartialDerivative(5), 8.0e-11); |
| } |
| |
| @Test(expected=NullArgumentException.class) |
| public void testParametricUsage1() { |
| final Logit.Parametric g = new Logit.Parametric(); |
| g.value(0, null); |
| } |
| |
| @Test(expected=DimensionMismatchException.class) |
| public void testParametricUsage2() { |
| final Logit.Parametric g = new Logit.Parametric(); |
| g.value(0, new double[] {0}); |
| } |
| |
| @Test(expected=NullArgumentException.class) |
| public void testParametricUsage3() { |
| final Logit.Parametric g = new Logit.Parametric(); |
| g.gradient(0, null); |
| } |
| |
| @Test(expected=DimensionMismatchException.class) |
| public void testParametricUsage4() { |
| final Logit.Parametric g = new Logit.Parametric(); |
| g.gradient(0, new double[] {0}); |
| } |
| |
| @Test(expected=OutOfRangeException.class) |
| public void testParametricUsage5() { |
| final Logit.Parametric g = new Logit.Parametric(); |
| g.value(-1, new double[] {0, 1}); |
| } |
| |
| @Test(expected=OutOfRangeException.class) |
| public void testParametricUsage6() { |
| final Logit.Parametric g = new Logit.Parametric(); |
| g.value(2, new double[] {0, 1}); |
| } |
| |
| @Test |
| public void testParametricValue() { |
| final double lo = 2; |
| final double hi = 3; |
| final Logit f = new Logit(lo, hi); |
| |
| final Logit.Parametric g = new Logit.Parametric(); |
| Assert.assertEquals(f.value(2), g.value(2, new double[] {lo, hi}), 0); |
| Assert.assertEquals(f.value(2.34567), g.value(2.34567, new double[] {lo, hi}), 0); |
| Assert.assertEquals(f.value(3), g.value(3, new double[] {lo, hi}), 0); |
| } |
| |
| @Test |
| public void testValueWithInverseFunction() { |
| final double lo = 2; |
| final double hi = 3; |
| final Logit f = new Logit(lo, hi); |
| final Sigmoid g = new Sigmoid(lo, hi); |
| final UniformRandomProvider random = RandomSource.create(RandomSource.WELL_1024_A, |
| 0x49914cdd9f0b8db5l); |
| final UnivariateDifferentiableFunction id = FunctionUtils.compose((UnivariateDifferentiableFunction) g, |
| (UnivariateDifferentiableFunction) f); |
| |
| for (int i = 0; i < 10; i++) { |
| final double x = lo + random.nextDouble() * (hi - lo); |
| Assert.assertEquals(x, id.value(new DerivativeStructure(1, 1, 0, x)).getValue(), EPS); |
| } |
| |
| Assert.assertEquals(lo, id.value(new DerivativeStructure(1, 1, 0, lo)).getValue(), EPS); |
| Assert.assertEquals(hi, id.value(new DerivativeStructure(1, 1, 0, hi)).getValue(), EPS); |
| } |
| |
| @Test |
| public void testDerivativesWithInverseFunction() { |
| double[] epsilon = new double[] { 1.0e-20, 4.0e-16, 3.0e-15, 2.0e-11, 3.0e-9, 1.0e-6 }; |
| final double lo = 2; |
| final double hi = 3; |
| final Logit f = new Logit(lo, hi); |
| final Sigmoid g = new Sigmoid(lo, hi); |
| final UniformRandomProvider random = RandomSource.create(RandomSource.WELL_1024_A, |
| 0x96885e9c1f81cea6l); |
| final UnivariateDifferentiableFunction id = |
| FunctionUtils.compose((UnivariateDifferentiableFunction) g, (UnivariateDifferentiableFunction) f); |
| for (int maxOrder = 0; maxOrder < 6; ++maxOrder) { |
| double max = 0; |
| for (int i = 0; i < 10; i++) { |
| final double x = lo + random.nextDouble() * (hi - lo); |
| final DerivativeStructure dsX = new DerivativeStructure(1, maxOrder, 0, x); |
| max = FastMath.max(max, FastMath.abs(dsX.getPartialDerivative(maxOrder) - |
| id.value(dsX).getPartialDerivative(maxOrder))); |
| Assert.assertEquals("maxOrder = " + maxOrder, |
| dsX.getPartialDerivative(maxOrder), |
| id.value(dsX).getPartialDerivative(maxOrder), |
| epsilon[maxOrder]); |
| } |
| |
| // each function evaluates correctly near boundaries, |
| // but combination leads to NaN as some intermediate point is infinite |
| final DerivativeStructure dsLo = new DerivativeStructure(1, maxOrder, 0, lo); |
| if (maxOrder == 0) { |
| Assert.assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder))); |
| Assert.assertEquals(lo, id.value(dsLo).getPartialDerivative(maxOrder), epsilon[maxOrder]); |
| } else if (maxOrder == 1) { |
| Assert.assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder))); |
| Assert.assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder))); |
| } else { |
| Assert.assertTrue(Double.isNaN(f.value(dsLo).getPartialDerivative(maxOrder))); |
| Assert.assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder))); |
| } |
| |
| final DerivativeStructure dsHi = new DerivativeStructure(1, maxOrder, 0, hi); |
| if (maxOrder == 0) { |
| Assert.assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder))); |
| Assert.assertEquals(hi, id.value(dsHi).getPartialDerivative(maxOrder), epsilon[maxOrder]); |
| } else if (maxOrder == 1) { |
| Assert.assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder))); |
| Assert.assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder))); |
| } else { |
| Assert.assertTrue(Double.isNaN(f.value(dsHi).getPartialDerivative(maxOrder))); |
| Assert.assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder))); |
| } |
| |
| } |
| } |
| } |