STATISTICS-32: New "survivalProbability" function for all discrete
distributions.
While a naive implementation would simply be
`1-cumulativeProbability`, that would result
in loss of precision.
For many of the current discrete distributions a higher
precision survival probability is calculated.
For others, it is simply `1-cumulativeProbability`.
Many tests were added to verify the following:
- the precision of cumulativeProbability
- the precision of survivalProbability
- That survivalProbability is near 1-cumulative
- That survival and cumulative probabilities are
complementary
Through this development, certain distributions
were found lacking precision for their
cumulativeProbabilities and were improved.
These were:
- BinomialDistribution
- HypergeometricDistribution
Expanding the tests for the Pascal distribution for the degenerate cases
found a bug in the p=1 degenerate case when x=0. The NaN return value
has been corrected to either 0 (x=0) or -infinity (x!=0).
diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/BinomialDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/BinomialDistribution.java
index e3090c0..0e382ab 100644
--- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/BinomialDistribution.java
+++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/BinomialDistribution.java
@@ -101,8 +101,24 @@
} else if (x >= numberOfTrials) {
ret = 1.0;
} else {
- ret = 1.0 - RegularizedBeta.value(probabilityOfSuccess,
- x + 1.0, (double) numberOfTrials - x);
+ // Use a helper function to compute the complement of the survival probability
+ ret = RegularizedBetaUtils.complement(probabilityOfSuccess,
+ x + 1.0, (double) numberOfTrials - x);
+ }
+ return ret;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public double survivalProbability(int x) {
+ double ret;
+ if (x < 0) {
+ ret = 1.0;
+ } else if (x >= numberOfTrials) {
+ ret = 0.0;
+ } else {
+ ret = RegularizedBeta.value(probabilityOfSuccess,
+ x + 1.0, (double) numberOfTrials - x);
}
return ret;
}
diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/DiscreteDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/DiscreteDistribution.java
index e2d6cb9..292828c 100644
--- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/DiscreteDistribution.java
+++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/DiscreteDistribution.java
@@ -73,6 +73,23 @@
double cumulativeProbability(int x);
/**
+ * For a random variable {@code X} whose values are distributed according
+ * to this distribution, this method returns {@code P(X > x)}.
+ * In other words, this method represents the complementary cumulative
+ * distribution function.
+ * <p>
+ * By default, this is defined as {@code 1 - cumulativeProbability(x)}, but
+ * the specific implementation may be more accurate.
+ *
+ * @param x Point at which the survival function is evaluated.
+ * @return the probability that a random variable with this
+ * distribution takes a value greater than {@code x}.
+ */
+ default double survivalProbability(int x) {
+ return 1.0 - cumulativeProbability(x);
+ }
+
+ /**
* Computes the quantile function of this distribution.
* For a random variable {@code X} distributed according to this distribution,
* the returned value is
diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/GeometricDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/GeometricDistribution.java
index d9e88cc..5e28975 100644
--- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/GeometricDistribution.java
+++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/GeometricDistribution.java
@@ -79,6 +79,15 @@
return -Math.expm1(log1mProbabilityOfSuccess * (x + 1));
}
+ /** {@inheritDoc} */
+ @Override
+ public double survivalProbability(int x) {
+ if (x < 0) {
+ return 1.0;
+ }
+ return Math.exp(log1mProbabilityOfSuccess * (x + 1));
+ }
+
/**
* {@inheritDoc}
*
diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/HypergeometricDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/HypergeometricDistribution.java
index cdeab67..8d4596c 100644
--- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/HypergeometricDistribution.java
+++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/HypergeometricDistribution.java
@@ -79,7 +79,24 @@
} else if (x >= domain[1]) {
ret = 1.0;
} else {
- ret = innerCumulativeProbability(domain[0], x, 1);
+ ret = innerCumulativeProbability(domain[0], x);
+ }
+
+ return ret;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public double survivalProbability(int x) {
+ double ret;
+
+ final int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
+ if (x < domain[0]) {
+ ret = 1.0;
+ } else if (x >= domain[1]) {
+ ret = 0.0;
+ } else {
+ ret = innerCumulativeProbability(domain[1], x + 1);
}
return ret;
@@ -168,22 +185,36 @@
} else {
final double p = (double) sampleSize / (double) populationSize;
final double q = (double) (populationSize - sampleSize) / (double) populationSize;
- final double p1 = SaddlePointExpansionUtils.logBinomialProbability(x,
- numberOfSuccesses, p, q);
- final double p2 =
- SaddlePointExpansionUtils.logBinomialProbability(sampleSize - x,
- populationSize - numberOfSuccesses, p, q);
- final double p3 =
- SaddlePointExpansionUtils.logBinomialProbability(sampleSize, populationSize, p, q);
- ret = p1 + p2 - p3;
+ ret = logProbability(x, p, q);
}
return ret;
}
/**
+ * Compute the log probability.
+ *
+ * @param x Value.
+ * @param p sample size / population size.
+ * @param q (population size - sample size) / population size
+ * @return log(P(X = x))
+ */
+ private double logProbability(int x, double p, double q) {
+ final double p1 =
+ SaddlePointExpansionUtils.logBinomialProbability(x, numberOfSuccesses, p, q);
+ final double p2 =
+ SaddlePointExpansionUtils.logBinomialProbability(sampleSize - x,
+ populationSize - numberOfSuccesses, p, q);
+ final double p3 =
+ SaddlePointExpansionUtils.logBinomialProbability(sampleSize, populationSize, p, q);
+ return p1 + p2 - p3;
+ }
+
+ /**
* For this distribution, {@code X}, this method returns {@code P(X >= x)}.
*
+ * <p>Note: This is not equal to {@link #survivalProbability(int)} which computes {@code P(X > x)}.
+ *
* @param x Value at which the CDF is evaluated.
* @return the upper tail CDF for this distribution.
*/
@@ -196,7 +227,7 @@
} else if (x > domain[1]) {
ret = 0.0;
} else {
- ret = innerCumulativeProbability(domain[1], x, -1);
+ ret = innerCumulativeProbability(domain[1], x);
}
return ret;
@@ -206,21 +237,32 @@
* For this distribution, {@code X}, this method returns
* {@code P(x0 <= X <= x1)}.
* This probability is computed by summing the point probabilities for the
- * values {@code x0, x0 + 1, x0 + 2, ..., x1}, in the order directed by
- * {@code dx}.
+ * values {@code x0, x0 + dx, x0 + 2 * dx, ..., x1}; the direction {@code dx} is determined
+ * using a comparison of the input bounds.
+ * This should be called by using {@code x0} as the domain limit and {@code x1}
+ * as the internal value. This will result in an initial sum over terms of increasing magnitude.
*
- * @param x0 Inclusive lower bound.
- * @param x1 Inclusive upper bound.
- * @param dx Direction of summation (1 indicates summing from x0 to x1, and
- * 0 indicates summing from x1 to x0).
+ * @param x0 Inclusive domain bound.
+ * @param x1 Inclusive internal bound.
* @return {@code P(x0 <= X <= x1)}.
*/
- private double innerCumulativeProbability(int x0, int x1, int dx) {
+ private double innerCumulativeProbability(int x0, int x1) {
+ // Assume the range is within the domain.
+ // Reuse the computation for probability(x) but avoid checking the domain for each call.
+ final double p = (double) sampleSize / (double) populationSize;
+ final double q = (double) (populationSize - sampleSize) / (double) populationSize;
int x = x0;
- double ret = probability(x);
- while (x != x1) {
- x += dx;
- ret += probability(x);
+ double ret = Math.exp(logProbability(x, p, q));
+ if (x0 < x1) {
+ while (x != x1) {
+ x++;
+ ret += Math.exp(logProbability(x, p, q));
+ }
+ } else {
+ while (x != x1) {
+ x--;
+ ret += Math.exp(logProbability(x, p, q));
+ }
}
return ret;
}
diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PascalDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PascalDistribution.java
index d2cdfdc..94d385b 100644
--- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PascalDistribution.java
+++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PascalDistribution.java
@@ -108,6 +108,9 @@
double ret;
if (x < 0) {
ret = 0.0;
+ } else if (x == 0) {
+ // Special case exploiting cancellation.
+ ret = Math.pow(probabilityOfSuccess, numberOfSuccesses);
} else {
ret = BinomialCoefficientDouble.value(x +
numberOfSuccesses - 1, numberOfSuccesses - 1) *
@@ -123,6 +126,9 @@
double ret;
if (x < 0) {
ret = Double.NEGATIVE_INFINITY;
+ } else if (x == 0) {
+ // Special case exploiting cancellation.
+ ret = logProbabilityOfSuccess * numberOfSuccesses;
} else {
ret = LogBinomialCoefficient.value(x +
numberOfSuccesses - 1, numberOfSuccesses - 1) +
@@ -145,6 +151,20 @@
return ret;
}
+ /** {@inheritDoc} */
+ @Override
+ public double survivalProbability(int x) {
+ double ret;
+ if (x < 0) {
+ ret = 1.0;
+ } else {
+ // Use a helper function to compute the complement of the cumulative probability
+ ret = RegularizedBetaUtils.complement(probabilityOfSuccess,
+ numberOfSuccesses, x + 1.0);
+ }
+ return ret;
+ }
+
/**
* {@inheritDoc}
*
diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PoissonDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PoissonDistribution.java
index a1cd928..c2649fa 100644
--- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PoissonDistribution.java
+++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PoissonDistribution.java
@@ -108,6 +108,19 @@
maxIterations);
}
+ /** {@inheritDoc} */
+ @Override
+ public double survivalProbability(int x) {
+ if (x < 0) {
+ return 1;
+ }
+ if (x == Integer.MAX_VALUE) {
+ return 0;
+ }
+ return RegularizedGamma.P.value((double) x + 1, mean, epsilon,
+ maxIterations);
+ }
+
/**
* Calculates the Poisson distribution function using a normal
* approximation. The {@code N(mean, sqrt(mean))} distribution is used
diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/RegularizedBetaUtils.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/RegularizedBetaUtils.java
new file mode 100644
index 0000000..eb8db88
--- /dev/null
+++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/RegularizedBetaUtils.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.statistics.distribution;
+
+import org.apache.commons.numbers.gamma.RegularizedBeta;
+
+/**
+ * Utilities for the <a href="http://mathworld.wolfram.com/RegularizedBetaFunction.html">
+ * Regularized Beta function</a> {@code I(x, a, b)}.
+ */
+final class RegularizedBetaUtils {
+ /** No instances. */
+ private RegularizedBetaUtils() {}
+
+ /**
+ * Compute the complement of the regularized beta function {@code I(x, a, b)}.
+ * <pre>
+ * 1 - I(x, a, b) = I(1 - x, b, a)
+ * </pre>
+ *
+ * @param x the value.
+ * @param a Parameter {@code a}.
+ * @param b Parameter {@code b}.
+ * @return the complement of the regularized beta function 1 - I(x, a, b).
+ */
+ static double complement(double x, double a, double b) {
+ // Identity of the regularized beta function: 1 - I_x(a, b) = I_{1-x}(b, a)
+ // Ideally call RegularizedBeta.value(1 - x, b, a) to maximise precision.
+ //
+ // The implementation of the beta function will use the complement based on a condition.
+ // Here we repeat the condition with a and b switched and testing 1 - x.
+ // This will avoid double inversion of the parameters.
+ final double mxp1 = 1 - x;
+ if (mxp1 > (b + 1) / (2 + b + a)) {
+ // Note: This drops the additional test '&& x <= (a + 1) / (2 + b + a)'
+ // The test is to avoid infinite method call recursion which does not apply
+ // in this case. See MATH-1067.
+
+ // Direct computation of the complement with the input x.
+ // Avoids loss of precision when x != 1 - (1-x)
+ return 1.0 - RegularizedBeta.value(x, a, b);
+ }
+ // Use the identity which should be computed directly by the RegularizedBeta implementation.
+ return RegularizedBeta.value(mxp1, b, a);
+ }
+}
diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/BinomialDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/BinomialDistributionTest.java
index 32147d1..f3fdac8 100644
--- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/BinomialDistributionTest.java
+++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/BinomialDistributionTest.java
@@ -78,6 +78,29 @@
//-------------------- Additional test cases -------------------------------
+ /** Test case n = 10, p = 0.3. */
+ @Test
+ void testSmallPValue() {
+ final BinomialDistribution dist = new BinomialDistribution(10, 0.3);
+ setDistribution(dist);
+ // computed using R version 3.4.4
+ setCumulativeTestValues(new double[] {0.00000000000000000000, 0.02824752489999998728, 0.14930834590000002793,
+ 0.38278278639999974153, 0.64961071840000017552, 0.84973166740000016794, 0.95265101260000006889,
+ 0.98940792160000001765, 0.99840961360000002323, 0.99985631409999997654, 0.99999409509999992451,
+ 1.00000000000000000000, 1.00000000000000000000});
+ setDensityTestValues(new double[] {0.0000000000000000000e+00, 2.8247524899999980341e-02,
+ 1.2106082099999991575e-01, 2.3347444049999999116e-01, 2.6682793199999993439e-01, 2.0012094900000007569e-01,
+ 1.0291934520000002584e-01, 3.6756909000000004273e-02, 9.0016919999999864960e-03, 1.4467005000000008035e-03,
+ 1.3778099999999990615e-04, 5.9048999999999949131e-06, 0.0000000000000000000e+00});
+ setInverseCumulativeTestValues(new int[] {0, 0, 0, 0, 1, 1, 8, 7, 6, 5, 5, 10});
+ verifyDensities();
+ verifyLogDensities();
+ verifyCumulativeProbabilities();
+ verifySurvivalProbability();
+ verifySurvivalAndCumulativeProbabilityComplement();
+ verifyInverseCumulativeProbabilities();
+ }
+
/** Test degenerate case p = 0 */
@Test
void testDegenerate0() {
@@ -90,7 +113,10 @@
setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d});
setInverseCumulativeTestValues(new int[] {0, 0});
verifyDensities();
+ verifyLogDensities();
verifyCumulativeProbabilities();
+ verifySurvivalProbability();
+ verifySurvivalAndCumulativeProbabilityComplement();
verifyInverseCumulativeProbabilities();
Assertions.assertEquals(0, dist.getSupportLowerBound());
Assertions.assertEquals(0, dist.getSupportUpperBound());
@@ -108,7 +134,10 @@
setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d});
setInverseCumulativeTestValues(new int[] {5, 5});
verifyDensities();
+ verifyLogDensities();
verifyCumulativeProbabilities();
+ verifySurvivalProbability();
+ verifySurvivalAndCumulativeProbabilityComplement();
verifyInverseCumulativeProbabilities();
Assertions.assertEquals(5, dist.getSupportLowerBound());
Assertions.assertEquals(5, dist.getSupportUpperBound());
@@ -126,7 +155,10 @@
setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d});
setInverseCumulativeTestValues(new int[] {0, 0});
verifyDensities();
+ verifyLogDensities();
verifyCumulativeProbabilities();
+ verifySurvivalProbability();
+ verifySurvivalAndCumulativeProbabilityComplement();
verifyInverseCumulativeProbabilities();
Assertions.assertEquals(0, dist.getSupportLowerBound());
Assertions.assertEquals(0, dist.getSupportUpperBound());
@@ -184,4 +216,22 @@
Assertions.assertEquals(trials / 2, p);
}
}
+
+ @Test
+ void testHighPrecisionCumulativeProbabilities() {
+ // computed using R version 3.4.4
+ setDistribution(new BinomialDistribution(100, 0.99));
+ setCumulativePrecisionTestPoints(new int[] {82, 81});
+ setCumulativePrecisionTestValues(new double[] {1.4061271955993513664e-17, 6.1128083336354843707e-19});
+ verifyCumulativeProbabilityPrecision();
+ }
+
+ @Test
+ void testHighPrecisionSurvivalProbabilities() {
+ // computed using R version 3.4.4
+ setDistribution(new BinomialDistribution(100, 0.01));
+ setSurvivalPrecisionTestPoints(new int[] {18, 19});
+ setSurvivalPrecisionTestValues(new double[] {6.1128083336353977038e-19, 2.4944165604029235392e-20});
+ verifySurvivalProbabilityPrecision();
+ }
}
diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionAbstractTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionAbstractTest.java
index 521dcf5..0df01dd 100644
--- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionAbstractTest.java
+++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionAbstractTest.java
@@ -16,6 +16,7 @@
*/
package org.apache.commons.statistics.distribution;
+import java.util.Arrays;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
@@ -38,6 +39,20 @@
* makeInverseCumulativeTestPoints() -- arguments used to test inverse cdf evaluation
* makeInverseCumulativeTestValues() -- expected inverse cdf values
* <p>
+ * If the discrete distribution provides higher precision implementations of cumulativeProbability
+ * and/or survivalProbability, the following methods should be implemented to provide testing.
+ * To use these tests, calculate the cumulativeProbability and survivalProbability such that their naive
+ * complement is exceptionally close to `1` and consequently could lose precision due to floating point
+ * arithmetic.
+ *
+ * NOTE: The default high-precision threshold is 1e-22.
+ * <pre>
+ * makeCumulativePrecisionTestPoints() -- high precision test inputs
+ * makeCumulativePrecisionTestValues() -- high precision expected results
+ * makeSurvivalPrecisionTestPoints() -- high precision test inputs
+ * makeSurvivalPrecisionTestValues() -- high precision expected results
+ * </pre>
+ * <p>
* To implement additional test cases with different distribution instances and test data,
* use the setXxx methods for the instance data in test cases and call the verifyXxx methods
* to verify results.
@@ -51,6 +66,9 @@
/** Tolerance used in comparing expected and returned values. */
private double tolerance = 1e-12;
+ /** Tolerance used in high precision tests. */
+ private double highPrecisionTolerance = 1e-22;
+
/** Arguments used to test probability density calculations. */
private int[] densityTestPoints;
@@ -66,6 +84,18 @@
/** Values used to test cumulative probability density calculations. */
private double[] cumulativeTestValues;
+ /** Arguments used to test cumulative probability precision, effectively any x where 1-cdf(x) would result in 1. */
+ private int[] cumulativePrecisionTestPoints;
+
+ /** Values used to test cumulative probability precision, usually exceptionally tiny values. */
+ private double[] cumulativePrecisionTestValues;
+
+ /** Arguments used to test survival probability precision, effectively any x where 1-sf(x) would result in 1. */
+ private int[] survivalPrecisionTestPoints;
+
+ /** Values used to test survival probability precision, usually exceptionally tiny values. */
+ private double[] survivalPrecisionTestValues;
+
/** Arguments used to test inverse cumulative probability density calculations. */
private double[] inverseCumulativeTestPoints;
@@ -91,12 +121,7 @@
* @return double[] the default logarithmic probability density test expected values.
*/
public double[] makeLogDensityTestValues() {
- final double[] density = makeDensityTestValues();
- final double[] logDensity = new double[density.length];
- for (int i = 0; i < density.length; i++) {
- logDensity[i] = Math.log(density[i]);
- }
- return logDensity;
+ return Arrays.stream(makeDensityTestValues()).map(Math::log).toArray();
}
/** Creates the default cumulative probability density test input values. */
@@ -105,6 +130,34 @@
/** Creates the default cumulative probability density test expected values. */
public abstract double[] makeCumulativeTestValues();
+ /** Creates the default cumulative probability precision test input values. */
+ public int[] makeCumulativePrecisionTestPoints() {
+ return new int[0];
+ }
+
+ /**
+ * Creates the default cumulative probability precision test expected values.
+ * Note: The default threshold is 1e-22, any expected values with much higher precision may
+ * not test the desired results without increasing precision threshold.
+ */
+ public double[] makeCumulativePrecisionTestValues() {
+ return new double[0];
+ }
+
+ /** Creates the default survival probability precision test input values. */
+ public int[] makeSurvivalPrecisionTestPoints() {
+ return new int[0];
+ }
+
+ /**
+ * Creates the default survival probability precision test expected values.
+ * Note: The default threshold is 1e-22, any expected values with much higher precision may
+ * not test the desired results without increasing precision threshold.
+ */
+ public double[] makeSurvivalPrecisionTestValues() {
+ return new double[0];
+ }
+
/** Creates the default inverse cumulative probability test input values. */
public abstract double[] makeInverseCumulativeTestPoints();
@@ -124,6 +177,10 @@
logDensityTestValues = makeLogDensityTestValues();
cumulativeTestPoints = makeCumulativeTestPoints();
cumulativeTestValues = makeCumulativeTestValues();
+ cumulativePrecisionTestPoints = makeCumulativePrecisionTestPoints();
+ cumulativePrecisionTestValues = makeCumulativePrecisionTestValues();
+ survivalPrecisionTestPoints = makeSurvivalPrecisionTestPoints();
+ survivalPrecisionTestValues = makeSurvivalPrecisionTestValues();
inverseCumulativeTestPoints = makeInverseCumulativeTestPoints();
inverseCumulativeTestValues = makeInverseCumulativeTestValues();
}
@@ -139,6 +196,10 @@
logDensityTestValues = null;
cumulativeTestPoints = null;
cumulativeTestValues = null;
+ cumulativePrecisionTestPoints = null;
+ cumulativePrecisionTestValues = null;
+ survivalPrecisionTestPoints = null;
+ survivalPrecisionTestValues = null;
inverseCumulativeTestPoints = null;
inverseCumulativeTestValues = null;
}
@@ -164,10 +225,9 @@
*/
protected void verifyLogDensities() {
for (int i = 0; i < densityTestPoints.length; i++) {
- // FIXME: when logProbability methods are added to DiscreteDistribution in 4.0, remove cast below
final int testPoint = densityTestPoints[i];
Assertions.assertEquals(logDensityTestValues[i],
- ((AbstractDiscreteDistribution) distribution).logProbability(testPoint), tolerance,
+ distribution.logProbability(testPoint), tolerance,
() -> "Incorrect log density value returned for " + testPoint);
}
}
@@ -185,6 +245,57 @@
}
}
+ protected void verifySurvivalProbability() {
+ for (int i = 0; i < cumulativeTestPoints.length; i++) {
+ final int x = cumulativeTestPoints[i];
+ Assertions.assertEquals(
+ 1 - cumulativeTestValues[i],
+ distribution.survivalProbability(cumulativeTestPoints[i]),
+ getTolerance(),
+ () -> "Incorrect survival probability value returned for " + x);
+ }
+ }
+
+ protected void verifySurvivalAndCumulativeProbabilityComplement() {
+ for (final int x : cumulativeTestPoints) {
+ Assertions.assertEquals(
+ 1.0,
+ distribution.survivalProbability(x) + distribution.cumulativeProbability(x),
+ getTolerance(),
+ () -> "survival + cumulative probability were not close to 1.0 for " + x);
+ }
+ }
+
+ /**
+ * Verifies that survival is simply not 1-cdf by testing calculations that would underflow that calculation and
+ * result in an inaccurate answer.
+ */
+ protected void verifySurvivalProbabilityPrecision() {
+ for (int i = 0; i < survivalPrecisionTestPoints.length; i++) {
+ final int x = survivalPrecisionTestPoints[i];
+ Assertions.assertEquals(
+ survivalPrecisionTestValues[i],
+ distribution.survivalProbability(x),
+ getHighPrecisionTolerance(),
+ () -> "survival probability is not precise for " + x);
+ }
+ }
+
+ /**
+ * Verifies that CDF is simply not 1-survival function by testing values that would result with inaccurate results
+ * if simply calculating 1-survival function.
+ */
+ protected void verifyCumulativeProbabilityPrecision() {
+ for (int i = 0; i < cumulativePrecisionTestPoints.length; i++) {
+ final int x = cumulativePrecisionTestPoints[i];
+ Assertions.assertEquals(
+ cumulativePrecisionTestValues[i],
+ distribution.cumulativeProbability(x),
+ getHighPrecisionTolerance(),
+ () -> "cumulative probability is not precise for " + x);
+ }
+ }
+
/**
* Verifies that inverse cumulative probability density calculations match expected values
* using current test instance data.
@@ -227,6 +338,26 @@
verifyCumulativeProbabilities();
}
+ @Test
+ void testSurvivalProbability() {
+ verifySurvivalProbability();
+ }
+
+ @Test
+ void testSurvivalAndCumulativeProbabilitiesAreComplementary() {
+ verifySurvivalAndCumulativeProbabilityComplement();
+ }
+
+ @Test
+ void testCumulativeProbabilityPrecision() {
+ verifyCumulativeProbabilityPrecision();
+ }
+
+ @Test
+ void testSurvivalProbabilityPrecision() {
+ verifySurvivalProbabilityPrecision();
+ }
+
/**
* Verifies that inverse cumulative probability density calculations match expected values
* using default test instance data.
@@ -240,9 +371,11 @@
void testConsistencyAtSupportBounds() {
final int lower = distribution.getSupportLowerBound();
Assertions.assertEquals(0.0, distribution.cumulativeProbability(lower - 1), 0.0,
- "Cumulative probability mmust be 0 below support lower bound.");
+ "Cumulative probability must be 0 below support lower bound.");
Assertions.assertEquals(distribution.probability(lower), distribution.cumulativeProbability(lower), getTolerance(),
"Cumulative probability of support lower bound must be equal to probability mass at this point.");
+ Assertions.assertEquals(1.0, distribution.survivalProbability(lower - 1), 0.0,
+ "Survival probability must be 1.0 below support lower bound.");
Assertions.assertEquals(lower, distribution.inverseCumulativeProbability(0.0),
"Inverse cumulative probability of 0 must be equal to support lower bound.");
@@ -250,6 +383,8 @@
if (upper != Integer.MAX_VALUE) {
Assertions.assertEquals(1.0, distribution.cumulativeProbability(upper), 0.0,
"Cumulative probability of support upper bound must be equal to 1.");
+ Assertions.assertEquals(0.0, distribution.survivalProbability(upper), 0.0,
+ "Survival probability of support upper bound must be equal to 0.");
}
Assertions.assertEquals(upper, distribution.inverseCumulativeProbability(1.0),
"Inverse cumulative probability of 1 must be equal to support upper bound.");
@@ -357,10 +492,84 @@
}
/**
+ * Set the density test values.
+ * For convenience this recomputes the log density test values using {@link Math#log(double)}.
+ *
* @param densityTestValues The densityTestValues to set.
*/
protected void setDensityTestValues(double[] densityTestValues) {
this.densityTestValues = densityTestValues;
+ logDensityTestValues = Arrays.stream(densityTestValues).map(Math::log).toArray();
+ }
+
+ /**
+ * @return Returns the logDensityTestValues.
+ */
+ protected double[] getLogDensityTestValues() {
+ return logDensityTestValues;
+ }
+
+ /**
+ * @param logDensityTestValues The logDensityTestValues to set.
+ */
+ protected void setLogDensityTestValues(double[] logDensityTestValues) {
+ this.logDensityTestValues = logDensityTestValues;
+ }
+
+ /**
+ * @return Returns the cumulativePrecisionTestPoints.
+ */
+ protected int[] getCumulativePrecisionTestPoints() {
+ return cumulativePrecisionTestPoints;
+ }
+
+ /**
+ * @param cumulativePrecisionTestPoints The cumulativePrecisionTestPoints to set.
+ */
+ protected void setCumulativePrecisionTestPoints(int[] cumulativePrecisionTestPoints) {
+ this.cumulativePrecisionTestPoints = cumulativePrecisionTestPoints;
+ }
+
+ /**
+ * @return Returns the cumulativePrecisionTestValues.
+ */
+ protected double[] getCumulativePrecisionTestValues() {
+ return cumulativePrecisionTestValues;
+ }
+
+ /**
+ * @param cumulativePrecisionTestValues The cumulativePrecisionTestValues to set.
+ */
+ protected void setCumulativePrecisionTestValues(double[] cumulativePrecisionTestValues) {
+ this.cumulativePrecisionTestValues = cumulativePrecisionTestValues;
+ }
+
+ /**
+ * @return Returns the survivalPrecisionTestPoints.
+ */
+ protected int[] getSurvivalPrecisionTestPoints() {
+ return survivalPrecisionTestPoints;
+ }
+
+ /**
+ * @param survivalPrecisionTestPoints The survivalPrecisionTestPoints to set.
+ */
+ protected void setSurvivalPrecisionTestPoints(int[] survivalPrecisionTestPoints) {
+ this.survivalPrecisionTestPoints = survivalPrecisionTestPoints;
+ }
+
+ /**
+ * @return Returns the survivalPrecisionTestValues.
+ */
+ protected double[] getSurvivalPrecisionTestValues() {
+ return survivalPrecisionTestValues;
+ }
+
+ /**
+ * @param survivalPrecisionTestValues The survivalPrecisionTestValues to set.
+ */
+ protected void setSurvivalPrecisionTestValues(double[] survivalPrecisionTestValues) {
+ this.survivalPrecisionTestValues = survivalPrecisionTestValues;
}
/**
@@ -420,6 +629,20 @@
}
/**
+ * @return Returns the high precision tolerance.
+ */
+ protected double getHighPrecisionTolerance() {
+ return highPrecisionTolerance;
+ }
+
+ /**
+ * @param highPrecisionTolerance The high precision tolerance to set.
+ */
+ protected void setHighPrecisionTolerance(double highPrecisionTolerance) {
+ this.highPrecisionTolerance = highPrecisionTolerance;
+ }
+
+ /**
* The expected value for {@link DiscreteDistribution#isSupportConnected()}.
* The default is {@code true}. Test class should override this when the distribution
* is not support connected.
diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionTest.java
index aa4960a..621f7aa 100644
--- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionTest.java
+++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionTest.java
@@ -40,7 +40,11 @@
}
@Override
public double cumulativeProbability(int x) {
- return 0;
+ // Return some different values to allow the survival probability to be tested
+ if (x < 0) {
+ return x < -5 ? 0.25 : 0.5;
+ }
+ return x > 5 ? 1.0 : 0.75;
}
@Override
public int inverseCumulativeProbability(double p) {
@@ -75,6 +79,8 @@
for (final int x : new int[] {Integer.MIN_VALUE, -1, 0, 1, 2, Integer.MAX_VALUE}) {
// Return the log of the density
Assertions.assertEquals(Math.log(x), dist.logProbability(x));
+ // Must return 1 - CDF(x)
+ Assertions.assertEquals(1.0 - dist.cumulativeProbability(x), dist.survivalProbability(x));
}
}
}
diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/GeometricDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/GeometricDistributionTest.java
index 45dc8d5..6efdb67 100644
--- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/GeometricDistributionTest.java
+++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/GeometricDistributionTest.java
@@ -138,6 +138,17 @@
};
}
+ @Override
+ public int[] makeSurvivalPrecisionTestPoints() {
+ return new int[] {74, 81};
+ }
+
+ @Override
+ public double[] makeSurvivalPrecisionTestValues() {
+ // computed using R version 3.4.4
+ return new double[] {2.2979669527522718895e-17, 6.4328367688565960968e-19};
+ }
+
//-------------------- Additional test cases -------------------------------
@Test
diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/HypergeometricDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/HypergeometricDistributionTest.java
index d2ce7ab..7dc251c 100644
--- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/HypergeometricDistributionTest.java
+++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/HypergeometricDistributionTest.java
@@ -101,7 +101,10 @@
setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d});
setInverseCumulativeTestValues(new int[] {3, 3});
verifyDensities();
+ verifyLogDensities();
verifyCumulativeProbabilities();
+ verifySurvivalProbability();
+ verifySurvivalAndCumulativeProbabilityComplement();
verifyInverseCumulativeProbabilities();
Assertions.assertEquals(3, dist.getSupportLowerBound());
Assertions.assertEquals(3, dist.getSupportUpperBound());
@@ -119,7 +122,10 @@
setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d});
setInverseCumulativeTestValues(new int[] {0, 0});
verifyDensities();
+ verifyLogDensities();
verifyCumulativeProbabilities();
+ verifySurvivalProbability();
+ verifySurvivalAndCumulativeProbabilityComplement();
verifyInverseCumulativeProbabilities();
Assertions.assertEquals(0, dist.getSupportLowerBound());
Assertions.assertEquals(0, dist.getSupportUpperBound());
@@ -137,7 +143,10 @@
setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d});
setInverseCumulativeTestValues(new int[] {3, 3});
verifyDensities();
+ verifyLogDensities();
verifyCumulativeProbabilities();
+ verifySurvivalProbability();
+ verifySurvivalAndCumulativeProbabilityComplement();
verifyInverseCumulativeProbabilities();
Assertions.assertEquals(3, dist.getSupportLowerBound());
Assertions.assertEquals(3, dist.getSupportUpperBound());
@@ -321,4 +330,22 @@
Assertions.assertTrue(sample <= n, () -> "sample=" + sample);
}
}
+
+ @Test
+ void testHighPrecisionCumulativeProbabilities() {
+ // computed using R version 3.4.4
+ setDistribution(new HypergeometricDistribution(500, 70, 300));
+ setCumulativePrecisionTestPoints(new int[] {10, 8});
+ setCumulativePrecisionTestValues(new double[] {2.4055720603264525e-17, 1.2848174992266236e-19});
+        verifyCumulativeProbabilityPrecision();
+ }
+
+ @Test
+ void testHighPrecisionSurvivalProbabilities() {
+ // computed using R version 3.4.4
+ setDistribution(new HypergeometricDistribution(500, 70, 300));
+ setSurvivalPrecisionTestPoints(new int[] {68, 69});
+ setSurvivalPrecisionTestValues(new double[] {4.570379934029859e-16, 7.4187180434325268e-18});
+ verifySurvivalProbabilityPrecision();
+ }
}
diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PascalDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PascalDistributionTest.java
index 66ecfe6..6d6467f 100644
--- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PascalDistributionTest.java
+++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PascalDistributionTest.java
@@ -78,6 +78,17 @@
return new int[] {0, 0, 0, 0, 1, 1, 14, 11, 10, 9, 8, Integer.MAX_VALUE};
}
+ @Override
+ public int[] makeSurvivalPrecisionTestPoints() {
+ return new int[] {47, 52};
+ }
+
+ @Override
+ public double[] makeSurvivalPrecisionTestValues() {
+ // computed using R version 3.4.4
+ return new double[] {3.1403888119656772712e-17, 1.7075879020163069251e-19};
+ }
+
//-------------------- Additional test cases -------------------------------
/** Test degenerate case p = 0 */
@@ -91,7 +102,10 @@
setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d});
setInverseCumulativeTestValues(new int[] {Integer.MAX_VALUE, Integer.MAX_VALUE});
verifyDensities();
+ verifyLogDensities();
verifyCumulativeProbabilities();
+ verifySurvivalProbability();
+ verifySurvivalAndCumulativeProbabilityComplement();
verifyInverseCumulativeProbabilities();
}
@@ -106,7 +120,10 @@
setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d});
setInverseCumulativeTestValues(new int[] {0, 0});
verifyDensities();
+ verifyLogDensities();
verifyCumulativeProbabilities();
+ verifySurvivalProbability();
+ verifySurvivalAndCumulativeProbabilityComplement();
verifyInverseCumulativeProbabilities();
}
diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PoissonDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PoissonDistributionTest.java
index f30f697..b908926 100644
--- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PoissonDistributionTest.java
+++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PoissonDistributionTest.java
@@ -95,6 +95,17 @@
return new int[] {0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 10, 20};
}
+ @Override
+ public int[] makeSurvivalPrecisionTestPoints() {
+ return new int[] {30, 32};
+ }
+
+ @Override
+ public double[] makeSurvivalPrecisionTestValues() {
+ // computed using R version 3.4.4
+ return new double[] {1.1732435431464340474e-17, 1.7630174687875970627e-19};
+ }
+
//-------------------- Additional test cases -------------------------------
/**
@@ -221,4 +232,13 @@
mean *= 10.0;
}
}
+
+ @Test
+ void testLargeMeanHighPrecisionCumulativeProbabilities() {
+ // computed using R version 3.4.4
+ setDistribution(new PoissonDistribution(100));
+ setCumulativePrecisionTestPoints(new int[] {28, 25});
+ setCumulativePrecisionTestValues(new double[] {1.6858675763053070496e-17, 3.184075559619425735e-19});
+ verifyCumulativeProbabilityPrecision();
+ }
}
diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/RegularizedBetaUtilsTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/RegularizedBetaUtilsTest.java
new file mode 100644
index 0000000..7756145
--- /dev/null
+++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/RegularizedBetaUtilsTest.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.statistics.distribution;
+
+import org.apache.commons.numbers.gamma.RegularizedBeta;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Test for {@link RegularizedBetaUtils}.
+ */
+class RegularizedBetaUtilsTest {
+ @Test
+ void testComplement() {
+ final double[] xs = {0, 0.1, 0.2, 0.25, 0.3, 1.0 / 3, 0.4, 0.5, 0.6, 2.0 / 3, 0.7, 0.75, 0.8, 0.9, 1};
+ // Called in PascalDistribution with a >= 1; b >= 1
+ // Called in BinomialDistribution with a >= 1; b >= 1
+ final double[] as = {1, 2, 3, 4, 5, 10, 20, 100, 1000};
+ final double[] bs = {1, 2, 3, 4, 5, 10, 20, 100, 1000};
+ for (final double x : xs) {
+ for (final double a : as) {
+ for (final double b : bs) {
+ assertComplement(x, a, b);
+ }
+ }
+ }
+ }
+
+ private static void assertComplement(double x, double a, double b) {
+ final double expected1 = 1.0 - RegularizedBeta.value(x, a, b);
+ final double expected2 = RegularizedBeta.value(1 - x, b, a);
+ final double actual = RegularizedBetaUtils.complement(x, a, b);
+        // Expect binary equality with one of the two reference computations
+ Assertions.assertTrue(expected1 == actual || expected2 == actual,
+ () -> String.format("I(%s, %s, %s) Expected %s or %s: Actual %s", x, a, b, expected1, expected2, actual));
+ }
+}