Add test assumption to the discrete distribution sampling test
This requires the distribution is tested against at least 50% of the
PMF.
diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionAbstractTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionAbstractTest.java
index fcc46b3..763c67f 100644
--- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionAbstractTest.java
+++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionAbstractTest.java
@@ -465,29 +465,43 @@
@Test
void testSampling() {
+ // This test uses the points that are used to test the distribution PMF.
+ // The sum of the probability values does not have to be 1 (or very close to 1).
+ // Any value generated by the sampler that is not an expected point will
+ // be ignored. If the sum of probabilities is above 0.5 then at least half
+ // of the samples should be counted and the test will verify these occur with
+ // the expected relative frequencies. Note: The expected values are normalised
+ // to 1 (i.e. relative frequencies) by the Chi-square test.
+ final int[] probabilityPoints = makeProbabilityTestPoints();
+ final double[] probabilityValues = makeProbabilityTestValues();
+ final int length = TestUtils.eliminateZeroMassPoints(probabilityPoints, probabilityValues);
+ final double[] expected = Arrays.copyOf(probabilityValues, length);
+
+ // This test will not be valid if the points do not represent enough of the PMF.
+ // Require at least 50%.
+ final double sum = Arrays.stream(expected).sum();
+ if (sum < 0.5) {
+ Assertions.fail("Not enough of the PMF is tested during sampling: " + sum);
+ }
+
// Use fixed seed.
final int sampleSize = 1000;
final DiscreteDistribution.Sampler sampler =
getDistribution().createSampler(RandomSource.create(RandomSource.WELL_512_A, 1000));
final int[] sample = TestUtils.sample(sampleSize, sampler);
- final int[] densityPoints = makeProbabilityTestPoints();
- final double[] densityValues = makeProbabilityTestValues();
- final int length = TestUtils.eliminateZeroMassPoints(densityPoints, densityValues);
- final double[] expected = Arrays.copyOf(densityValues, length);
-
final long[] counts = new long[length];
for (int i = 0; i < sampleSize; i++) {
final int x = sample[i];
for (int j = 0; j < length; j++) {
- if (x == densityPoints[j]) {
+ if (x == probabilityPoints[j]) {
counts[j]++;
break;
}
}
}
- TestUtils.assertChiSquareAccept(densityPoints, expected, counts, 0.001);
+ TestUtils.assertChiSquareAccept(probabilityPoints, expected, counts, 0.001);
}
/**