| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.commons.math3.distribution; |
| |
| import static org.junit.Assert.assertEquals; |
| |
| import java.util.ArrayList; |
| import java.util.List; |
| |
| import org.apache.commons.math3.exception.DimensionMismatchException; |
| import org.apache.commons.math3.exception.MathArithmeticException; |
| import org.apache.commons.math3.exception.NotANumberException; |
| import org.apache.commons.math3.exception.NotFiniteNumberException; |
| import org.apache.commons.math3.exception.NotPositiveException; |
| import org.apache.commons.math3.util.FastMath; |
| import org.apache.commons.math3.util.Pair; |
| import org.junit.Assert; |
| import org.junit.Test; |
| |
| /** |
| * Test class for {@link EnumeratedRealDistribution}. |
| * |
| */ |
| public class EnumeratedRealDistributionTest { |
| |
| /** |
| * The distribution object used for testing. |
| */ |
| private final EnumeratedRealDistribution testDistribution; |
| |
| /** |
| * Creates the default distribution object used for testing. |
| */ |
| public EnumeratedRealDistributionTest() { |
| // Non-sorted singleton array with duplicates should be allowed. |
| // Values with zero-probability do not extend the support. |
| testDistribution = new EnumeratedRealDistribution( |
| new double[]{3.0, -1.0, 3.0, 7.0, -2.0, 8.0}, |
| new double[]{0.2, 0.2, 0.3, 0.3, 0.0, 0.0}); |
| } |
| |
| /** |
| * Tests if the {@link EnumeratedRealDistribution} constructor throws |
| * exceptions for invalid data. |
| */ |
| @Test |
| public void testExceptions() { |
| EnumeratedRealDistribution invalid = null; |
| try { |
| invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0}); |
| Assert.fail("Expected DimensionMismatchException"); |
| } catch (DimensionMismatchException e) { |
| } |
| try{ |
| invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, -1.0}); |
| Assert.fail("Expected NotPositiveException"); |
| } catch (NotPositiveException e) { |
| } |
| try { |
| invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, 0.0}); |
| Assert.fail("Expected MathArithmeticException"); |
| } catch (MathArithmeticException e) { |
| } |
| try { |
| invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, Double.NaN}); |
| Assert.fail("Expected NotANumberException"); |
| } catch (NotANumberException e) { |
| } |
| try { |
| invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, Double.POSITIVE_INFINITY}); |
| Assert.fail("Expected NotFiniteNumberException"); |
| } catch (NotFiniteNumberException e) { |
| } |
| Assert.assertNull("Expected non-initialized DiscreteRealDistribution", invalid); |
| } |
| |
| /** |
| * Tests if the distribution returns proper probability values. |
| */ |
| @Test |
| public void testProbability() { |
| double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; |
| double[] results = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0}; |
| for (int p = 0; p < points.length; p++) { |
| double density = testDistribution.probability(points[p]); |
| Assert.assertEquals(results[p], density, 0.0); |
| } |
| } |
| |
| /** |
| * Tests if the distribution returns proper density values. |
| */ |
| @Test |
| public void testDensity() { |
| double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; |
| double[] results = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0}; |
| for (int p = 0; p < points.length; p++) { |
| double density = testDistribution.density(points[p]); |
| Assert.assertEquals(results[p], density, 0.0); |
| } |
| } |
| |
| /** |
| * Tests if the distribution returns proper cumulative probability values. |
| */ |
| @Test |
| public void testCumulativeProbability() { |
| double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; |
| double[] results = new double[]{0, 0.2, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7, 0.7, 1.0, 1.0}; |
| for (int p = 0; p < points.length; p++) { |
| double probability = testDistribution.cumulativeProbability(points[p]); |
| Assert.assertEquals(results[p], probability, 1e-10); |
| } |
| } |
| |
| /** |
| * Tests if the distribution returns proper mean value. |
| */ |
| @Test |
| public void testGetNumericalMean() { |
| Assert.assertEquals(3.4, testDistribution.getNumericalMean(), 1e-10); |
| } |
| |
| /** |
| * Tests if the distribution returns proper variance. |
| */ |
| @Test |
| public void testGetNumericalVariance() { |
| Assert.assertEquals(7.84, testDistribution.getNumericalVariance(), 1e-10); |
| } |
| |
| /** |
| * Tests if the distribution returns proper lower bound. |
| */ |
| @Test |
| public void testGetSupportLowerBound() { |
| Assert.assertEquals(-1, testDistribution.getSupportLowerBound(), 0); |
| } |
| |
| /** |
| * Tests if the distribution returns proper upper bound. |
| */ |
| @Test |
| public void testGetSupportUpperBound() { |
| Assert.assertEquals(7, testDistribution.getSupportUpperBound(), 0); |
| } |
| |
| /** |
| * Tests if the distribution returns properly that the support includes the |
| * lower bound. |
| */ |
| @Test |
| public void testIsSupportLowerBoundInclusive() { |
| Assert.assertTrue(testDistribution.isSupportLowerBoundInclusive()); |
| } |
| |
| /** |
| * Tests if the distribution returns properly that the support includes the |
| * upper bound. |
| */ |
| @Test |
| public void testIsSupportUpperBoundInclusive() { |
| Assert.assertTrue(testDistribution.isSupportUpperBoundInclusive()); |
| } |
| |
| /** |
| * Tests if the distribution returns properly that the support is connected. |
| */ |
| @Test |
| public void testIsSupportConnected() { |
| Assert.assertTrue(testDistribution.isSupportConnected()); |
| } |
| |
| /** |
| * Tests sampling. |
| */ |
| @Test |
| public void testSample() { |
| final int n = 1000000; |
| testDistribution.reseedRandomGenerator(-334759360); // fixed seed |
| final double[] samples = testDistribution.sample(n); |
| Assert.assertEquals(n, samples.length); |
| double sum = 0; |
| double sumOfSquares = 0; |
| for (int i = 0; i < samples.length; i++) { |
| sum += samples[i]; |
| sumOfSquares += samples[i] * samples[i]; |
| } |
| Assert.assertEquals(testDistribution.getNumericalMean(), |
| sum / n, 1e-2); |
| Assert.assertEquals(testDistribution.getNumericalVariance(), |
| sumOfSquares / n - FastMath.pow(sum / n, 2), 1e-2); |
| } |
| |
| @Test |
| public void testIssue942() { |
| List<Pair<Object,Double>> list = new ArrayList<Pair<Object, Double>>(); |
| list.add(new Pair<Object, Double>(new Object() {}, new Double(0))); |
| list.add(new Pair<Object, Double>(new Object() {}, new Double(1))); |
| Assert.assertEquals(1, new EnumeratedDistribution<Object>(list).sample(1).length); |
| } |
| |
| @Test |
| public void testIssue1065() { |
| // Test Distribution for inverseCumulativeProbability |
| // |
| // ^ |
| // | |
| // 1.000 +--------------------------------o=============== |
| // | 3| |
| // | | |
| // | 1o= |
| // 0.750 +-------------------------> o== . |
| // | 3| . . |
| // | 0 | . . |
| // 0.5625 +---------------> o==o====== . . |
| // | | . . . . |
| // | | . . . . |
| // | 5| . . . . |
| // | | . . . . |
| // | o=== . . . . |
| // | | . . . . . |
| // | 4| . . . . . |
| // | | . . . . . |
| // 0.000 +=============----+--+------+--+-+---------------> |
| // 14 18 21 28 31 33 |
| // |
| // sum = 4+5+0+3+1+3 = 16 |
| |
| EnumeratedRealDistribution distribution = new EnumeratedRealDistribution( |
| new double[] { 14.0, 18.0, 21.0, 28.0, 31.0, 33.0 }, |
| new double[] { 4.0 / 16.0, 5.0 / 16.0, 0.0 / 16.0, 3.0 / 16.0, 1.0 / 16.0, 3.0 / 16.0 }); |
| |
| assertEquals(14.0, distribution.inverseCumulativeProbability(0.0000), 0.0); |
| assertEquals(14.0, distribution.inverseCumulativeProbability(0.2500), 0.0); |
| assertEquals(33.0, distribution.inverseCumulativeProbability(1.0000), 0.0); |
| |
| assertEquals(18.0, distribution.inverseCumulativeProbability(0.5000), 0.0); |
| assertEquals(18.0, distribution.inverseCumulativeProbability(0.5624), 0.0); |
| assertEquals(28.0, distribution.inverseCumulativeProbability(0.5626), 0.0); |
| assertEquals(31.0, distribution.inverseCumulativeProbability(0.7600), 0.0); |
| assertEquals(18.0, distribution.inverseCumulativeProbability(0.5625), 0.0); |
| assertEquals(28.0, distribution.inverseCumulativeProbability(0.7500), 0.0); |
| } |
| |
| @Test |
| public void testCreateFromDoubles() { |
| final double[] data = new double[] {0, 1, 1, 2, 2, 2}; |
| EnumeratedRealDistribution distribution = new EnumeratedRealDistribution(data); |
| assertEquals(0.5, distribution.probability(2), 0); |
| assertEquals(0.5, distribution.cumulativeProbability(1), 0); |
| } |
| } |