// blob: e9155fe96847a2eeda28750f5b7e3c6d145b21a9 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.distribution;
import static org.junit.Assert.assertEquals;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.exception.NotANumberException;
import org.apache.commons.math3.exception.NotFiniteNumberException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair;
import org.junit.Assert;
import org.junit.Test;
/**
 * Unit tests for {@link EnumeratedRealDistribution}.
 *
 * <p>Most tests run against a shared fixture built in the constructor; the
 * regression tests ({@code testIssue942}, {@code testIssue1065}) and
 * {@code testCreateFromDoubles} build their own distributions.</p>
 */
public class EnumeratedRealDistributionTest {

    /**
     * The distribution object used for testing.
     */
    private final EnumeratedRealDistribution testDistribution;

    /**
     * Creates the default distribution object used for testing.
     *
     * <p>The fixture is deliberately awkward: the values are unsorted, the
     * value {@code 3.0} appears twice (weights 0.2 and 0.3), and the values
     * {@code -2.0} and {@code 8.0} carry zero probability. The tests in this
     * class expect the resulting point masses to be
     * {@code P(-1) = 0.2}, {@code P(3) = 0.5} and {@code P(7) = 0.3}.</p>
     */
    public EnumeratedRealDistributionTest() {
        // An unsorted array of singletons containing duplicates should be allowed.
        // Values with zero probability do not extend the support.
        testDistribution = new EnumeratedRealDistribution(
                new double[]{3.0, -1.0, 3.0, 7.0, -2.0, 8.0},
                new double[]{0.2, 0.2, 0.3, 0.3, 0.0, 0.0});
    }

    /**
     * Tests if the {@link EnumeratedRealDistribution} constructor throws
     * exceptions for invalid data.
     *
     * <p>Each case passes two singletons together with an invalid probability
     * array. {@code invalid} is only assigned if a constructor call
     * unexpectedly succeeds, so the final check verifies that it is still
     * {@code null}.</p>
     */
    @Test
    public void testExceptions() {
        EnumeratedRealDistribution invalid = null;

        // Fewer probabilities than singletons.
        try {
            invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0});
            Assert.fail("Expected DimensionMismatchException");
        } catch (DimensionMismatchException expected) {
            // Expected: nothing to do.
        }

        // Negative probability.
        try {
            invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, -1.0});
            Assert.fail("Expected NotPositiveException");
        } catch (NotPositiveException expected) {
            // Expected: nothing to do.
        }

        // All probabilities are zero.
        try {
            invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, 0.0});
            Assert.fail("Expected MathArithmeticException");
        } catch (MathArithmeticException expected) {
            // Expected: nothing to do.
        }

        // NaN probability.
        try {
            invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, Double.NaN});
            Assert.fail("Expected NotANumberException");
        } catch (NotANumberException expected) {
            // Expected: nothing to do.
        }

        // Infinite probability.
        try {
            invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, Double.POSITIVE_INFINITY});
            Assert.fail("Expected NotFiniteNumberException");
        } catch (NotFiniteNumberException expected) {
            // Expected: nothing to do.
        }

        Assert.assertNull("Expected non-initialized EnumeratedRealDistribution", invalid);
    }

    /**
     * Tests if the distribution returns proper probability values.
     *
     * <p>The duplicated value {@code 3.0} is expected to carry the sum of its
     * two weights (0.2 + 0.3), and every point that is not a singleton with
     * positive weight is expected to have probability zero.</p>
     */
    @Test
    public void testProbability() {
        double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
        double[] expected = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0};
        for (int p = 0; p < points.length; p++) {
            double probability = testDistribution.probability(points[p]);
            Assert.assertEquals(expected[p], probability, 0.0);
        }
    }

    /**
     * Tests if the distribution returns proper density values.
     *
     * <p>The expected values are identical to those of
     * {@link #testProbability()}: {@code density(x)} is expected to report the
     * point mass at {@code x}.</p>
     */
    @Test
    public void testDensity() {
        double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
        double[] expected = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0};
        for (int p = 0; p < points.length; p++) {
            double density = testDistribution.density(points[p]);
            Assert.assertEquals(expected[p], density, 0.0);
        }
    }

    /**
     * Tests if the distribution returns proper cumulative probability values.
     *
     * <p>The expected values step up at {@code -1} (to 0.2), {@code 3}
     * (to 0.7) and {@code 7} (to 1.0) and are compared with a tolerance of
     * {@code 1e-10}.</p>
     */
    @Test
    public void testCumulativeProbability() {
        double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
        double[] expected = new double[]{0, 0.2, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7, 0.7, 1.0, 1.0};
        for (int p = 0; p < points.length; p++) {
            double probability = testDistribution.cumulativeProbability(points[p]);
            Assert.assertEquals(expected[p], probability, 1e-10);
        }
    }

    /**
     * Tests if the distribution returns proper mean value.
     */
    @Test
    public void testGetNumericalMean() {
        // -1 * 0.2 + 3 * 0.5 + 7 * 0.3 = 3.4
        Assert.assertEquals(3.4, testDistribution.getNumericalMean(), 1e-10);
    }

    /**
     * Tests if the distribution returns proper variance.
     */
    @Test
    public void testGetNumericalVariance() {
        // E[X^2] - E[X]^2 = (1 * 0.2 + 9 * 0.5 + 49 * 0.3) - 3.4^2 = 19.4 - 11.56 = 7.84
        Assert.assertEquals(7.84, testDistribution.getNumericalVariance(), 1e-10);
    }

    /**
     * Tests if the distribution returns proper lower bound.
     *
     * <p>The expected bound is {@code -1}, not {@code -2}: the latter has
     * zero probability in the fixture.</p>
     */
    @Test
    public void testGetSupportLowerBound() {
        Assert.assertEquals(-1, testDistribution.getSupportLowerBound(), 0);
    }

    /**
     * Tests if the distribution returns proper upper bound.
     *
     * <p>The expected bound is {@code 7}, not {@code 8}: the latter has
     * zero probability in the fixture.</p>
     */
    @Test
    public void testGetSupportUpperBound() {
        Assert.assertEquals(7, testDistribution.getSupportUpperBound(), 0);
    }

    /**
     * Tests if the distribution returns properly that the support includes the
     * lower bound.
     */
    @Test
    public void testIsSupportLowerBoundInclusive() {
        Assert.assertTrue(testDistribution.isSupportLowerBoundInclusive());
    }

    /**
     * Tests if the distribution returns properly that the support includes the
     * upper bound.
     */
    @Test
    public void testIsSupportUpperBoundInclusive() {
        Assert.assertTrue(testDistribution.isSupportUpperBoundInclusive());
    }

    /**
     * Tests if the distribution returns properly that the support is connected.
     */
    @Test
    public void testIsSupportConnected() {
        Assert.assertTrue(testDistribution.isSupportConnected());
    }

    /**
     * Tests sampling.
     *
     * <p>Draws one million samples with a fixed seed and checks that the
     * sample mean and the sample variance are within {@code 1e-2} of the
     * values reported by the distribution.</p>
     */
    @Test
    public void testSample() {
        final int n = 1000000;
        testDistribution.reseedRandomGenerator(-334759360); // fixed seed
        final double[] samples = testDistribution.sample(n);
        Assert.assertEquals(n, samples.length);

        double sum = 0;
        double sumOfSquares = 0;
        for (double sample : samples) {
            sum += sample;
            sumOfSquares += sample * sample;
        }

        final double sampleMean = sum / n;
        Assert.assertEquals(testDistribution.getNumericalMean(),
                sampleMean, 1e-2);
        // Variance computed as E[X^2] - E[X]^2.
        Assert.assertEquals(testDistribution.getNumericalVariance(),
                sumOfSquares / n - FastMath.pow(sampleMean, 2), 1e-2);
    }

    /**
     * Regression test for issue 942 (as indicated by the method name).
     *
     * <p>Builds an {@link EnumeratedDistribution} over two instances of
     * distinct anonymous {@code Object} subclasses, with weights 0 and 1, and
     * checks that requesting one sample yields an array of length one.</p>
     */
    @Test
    public void testIssue942() {
        // NOTE(review): presumably the two distinct anonymous classes are what the
        // original issue was about -- confirm against the issue tracker.
        List<Pair<Object, Double>> list = new ArrayList<Pair<Object, Double>>();
        list.add(new Pair<Object, Double>(new Object() {}, Double.valueOf(0.0)));
        list.add(new Pair<Object, Double>(new Object() {}, Double.valueOf(1.0)));
        Assert.assertEquals(1, new EnumeratedDistribution<Object>(list).sample(1).length);
    }

    /**
     * Regression test for issue 1065 (as indicated by the method name):
     * checks {@code inverseCumulativeProbability} on a distribution that
     * contains a zero-probability point.
     */
    @Test
    public void testIssue1065() {
        // Test distribution for inverseCumulativeProbability.
        //
        // Point masses are given as weights out of 16 (sum = 4+5+0+3+1+3 = 16):
        //
        //     x    | 14     | 18     | 21     | 28     | 31     | 33
        //     P(x) | 4/16   | 5/16   | 0      | 3/16   | 1/16   | 3/16
        //     CDF  | 0.2500 | 0.5625 | 0.5625 | 0.7500 | 0.8125 | 1.0000
        //
        // The assertions below are consistent with returning the smallest
        // point whose cumulative probability is at least p; the point 21
        // (zero mass) is never expected as a result.
        EnumeratedRealDistribution distribution = new EnumeratedRealDistribution(
                new double[] { 14.0, 18.0, 21.0, 28.0, 31.0, 33.0 },
                new double[] { 4.0 / 16.0, 5.0 / 16.0, 0.0 / 16.0, 3.0 / 16.0, 1.0 / 16.0, 3.0 / 16.0 });

        // Boundaries of the probability range.
        assertEquals(14.0, distribution.inverseCumulativeProbability(0.0000), 0.0);
        assertEquals(14.0, distribution.inverseCumulativeProbability(0.2500), 0.0);
        assertEquals(33.0, distribution.inverseCumulativeProbability(1.0000), 0.0);

        // Values strictly between two CDF steps.
        assertEquals(18.0, distribution.inverseCumulativeProbability(0.5000), 0.0);
        assertEquals(18.0, distribution.inverseCumulativeProbability(0.5624), 0.0);
        assertEquals(28.0, distribution.inverseCumulativeProbability(0.5626), 0.0);
        assertEquals(31.0, distribution.inverseCumulativeProbability(0.7600), 0.0);

        // Values that hit a CDF step exactly.
        assertEquals(18.0, distribution.inverseCumulativeProbability(0.5625), 0.0);
        assertEquals(28.0, distribution.inverseCumulativeProbability(0.7500), 0.0);
    }

    /**
     * Tests the constructor that takes only an array of observed values.
     *
     * <p>With the data {@code {0, 1, 1, 2, 2, 2}} the value {@code 2} is
     * expected to have probability 3/6 and the cumulative probability at
     * {@code 1} is expected to be (1 + 2)/6.</p>
     */
    @Test
    public void testCreateFromDoubles() {
        final double[] data = new double[] {0, 1, 1, 2, 2, 2};
        EnumeratedRealDistribution distribution = new EnumeratedRealDistribution(data);
        assertEquals(0.5, distribution.probability(2), 0);
        assertEquals(0.5, distribution.cumulativeProbability(1), 0);
    }
}