blob: 6b56fec07ad7e2a9b01639f4ca58eb6862d9a292 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.statistics.distribution;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.commons.numbers.gamma.LanczosApproximation;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
/**
* Test cases for {@link GammaDistribution}.
* Extends {@link BaseContinuousDistributionTest}. See javadoc of that class for details.
*/
class GammaDistributionTest extends BaseContinuousDistributionTest {
private static final double HALF_LOG_2_PI = 0.5 * Math.log(2.0 * Math.PI);
/**
 * Builds the gamma distribution under test from the supplied parameters.
 *
 * @param parameters Shape at index 0 and scale at index 1, each a boxed {@code Double}.
 * @return the distribution created by {@code GammaDistribution.of(shape, scale)}
 */
@Override
ContinuousDistribution makeDistribution(Object... parameters) {
    final double alpha = ((Double) parameters[0]).doubleValue();
    final double theta = ((Double) parameters[1]).doubleValue();
    return GammaDistribution.of(alpha, theta);
}
/**
 * Supplies (shape, scale) pairs in which exactly one parameter is zero or negative.
 * Presumably the base test expects each pair to be rejected on construction;
 * confirm against {@code BaseContinuousDistributionTest}.
 *
 * @return the parameter rows, one {@code {shape, scale}} pair per row
 */
@Override
Object[][] makeInvalidParameters() {
    final Object[] zeroShape = {0.0, 1.0};
    final Object[] negativeShape = {-0.1, 1.0};
    final Object[] zeroScale = {1.0, 0.0};
    final Object[] negativeScale = {1.0, -0.1};
    return new Object[][] {zeroShape, negativeShape, zeroScale, negativeScale};
}
/**
 * Names of the distribution parameters, in the order consumed by
 * {@code makeDistribution}.
 *
 * @return {@code {"Shape", "Scale"}}
 */
@Override
String[] getParameterNames() {
    final String[] names = {"Shape", "Scale"};
    return names;
}
/**
 * Overrides the base-class tolerance hook.
 *
 * @return the relative tolerance {@code 1e-15}
 */
@Override
protected double getRelativeTolerance() {
    final double relativeTolerance = 1e-15;
    return relativeTolerance;
}
//-------------------- Additional test cases -------------------------------
/**
 * Checks the mean and variance for two (shape, scale) pairs against
 * {@code shape * scale} and {@code shape * scale * scale} using exact equality.
 */
@Test
void testAdditionalMoments() {
    // shape = 1, scale = 2
    final GammaDistribution integerCase = GammaDistribution.of(1, 2);
    Assertions.assertEquals(2, integerCase.getMean());
    Assertions.assertEquals(4, integerCase.getVariance());

    // shape = 1.1, scale = 4.2
    final GammaDistribution fractionalCase = GammaDistribution.of(1.1, 4.2);
    Assertions.assertEquals(1.1 * 4.2, fractionalCase.getMean());
    Assertions.assertEquals(1.1 * 4.2 * 4.2, fractionalCase.getVariance());
}
/**
 * Checks the CDF at a handful of points for several (shape, scale) pairs.
 */
@Test
void testAdditionalCumulativeProbability() {
    // Columns: x, shape, scale, expected cumulative probability
    final double[][] cases = {
        {-1.000, 4.0, 2.0, 0.0},
        {15.501, 4.0, 2.0, 0.94989465156755404},
        {0.504, 4.0, 1.0, 0.0018026739713985257},
        {10.011, 1.0, 2.0, 0.99329900998454213},
        {5.000, 2.0, 2.0, 0.71270250481635422},
    };
    for (final double[] row : cases) {
        checkCumulativeProbability(row[0], row[1], row[2], row[3]);
    }
}
/**
 * Asserts that the CDF of the gamma distribution {@code of(a, b)} at {@code x}
 * equals {@code expected} to a relative tolerance of 1e-15.
 *
 * @param x Point at which the CDF is evaluated.
 * @param a First argument passed to {@code GammaDistribution.of} (shape).
 * @param b Second argument passed to {@code GammaDistribution.of} (scale).
 * @param expected Expected cumulative probability.
 */
private void checkCumulativeProbability(double x, double a, double b, double expected) {
    final double tolerance = expected * 1e-15;
    final double computed = GammaDistribution.of(a, b).cumulativeProbability(x);
    Assertions.assertEquals(expected, computed, tolerance, () -> "probability for " + x);
}
/**
 * Checks the inverse CDF for several (shape, scale) pairs.
 */
@Test
void testAdditionalInverseCumulativeProbability() {
    // Columns: expected x, shape, scale, probability
    final double[][] cases = {
        {15.501, 4.0, 2.0, 0.94989465156755404},
        {0.504, 4.0, 1.0, 0.0018026739713985257},
        {10.011, 1.0, 2.0, 0.99329900998454213},
        {5.000, 2.0, 2.0, 0.71270250481635422},
    };
    for (final double[] row : cases) {
        checkInverseCumulativeProbability(row[0], row[1], row[2], row[3]);
    }
}
/**
 * Asserts that the inverse CDF of the gamma distribution {@code of(a, b)} at
 * probability {@code p} equals {@code expected} to a relative tolerance of 1e-10.
 *
 * @param expected Expected critical value.
 * @param a First argument passed to {@code GammaDistribution.of} (shape).
 * @param b Second argument passed to {@code GammaDistribution.of} (scale).
 * @param p Probability at which the inverse CDF is evaluated.
 */
private void checkInverseCumulativeProbability(double expected, double a, double b, double p) {
    final double tolerance = expected * 1e-10;
    final double computed = GammaDistribution.of(a, b).inverseCumulativeProbability(p);
    Assertions.assertEquals(expected, computed, tolerance, () -> "critical value for " + p);
}
/**
 * Checks the density at a fixed set of points for several (shape, rate) pairs.
 * The reference values are quoted from the R commands in the comments, where
 * {@code x} is the {@code points} array below.
 */
@Test
void testAdditionalDensity() {
    final double[] points = {-0.1, 1e-6, 0.5, 1, 2, 5};

    // R2.5: print(dgamma(x, shape=1, rate=1), digits=10)
    final double[] shape1Rate1 = {
        0.000000000000, 0.999999000001, 0.606530659713,
        0.367879441171, 0.135335283237, 0.006737946999
    };
    checkDensity(1, 1, points, shape1Rate1);

    // R2.5: print(dgamma(x, shape=2, rate=1), digits=10)
    final double[] shape2Rate1 = {
        0.000000000000, 0.000000999999, 0.303265329856,
        0.367879441171, 0.270670566473, 0.033689734995
    };
    checkDensity(2, 1, points, shape2Rate1);

    // R2.5: print(dgamma(x, shape=4, rate=1), digits=10)
    final double[] shape4Rate1 = {
        0.000000000e+00, 1.666665000e-19, 1.263605541e-02,
        6.131324020e-02, 1.804470443e-01, 1.403738958e-01
    };
    checkDensity(4, 1, points, shape4Rate1);

    // R2.5: print(dgamma(x, shape=4, rate=10), digits=10)
    final double[] shape4Rate10 = {
        0.000000000e+00, 1.666650000e-15, 1.403738958e+00,
        7.566654960e-02, 2.748204830e-05, 4.018228850e-17
    };
    checkDensity(4, 10, points, shape4Rate10);

    // R2.5: print(dgamma(x, shape=.1, rate=10), digits=10)
    final double[] shapeTenthRate10 = {
        0.000000000e+00, 3.323953832e+04, 1.663849010e-03,
        6.007786726e-06, 1.461647647e-10, 5.996008322e-24
    };
    checkDensity(0.1, 10, points, shapeTenthRate10);

    // R2.5: print(dgamma(x, shape=.1, rate=20), digits=10)
    final double[] shapeTenthRate20 = {
        0.000000000e+00, 3.562489883e+04, 1.201557345e-05,
        2.923295295e-10, 3.228910843e-19, 1.239484589e-45
    };
    checkDensity(0.1, 20, points, shapeTenthRate20);

    // R2.5: print(dgamma(x, shape=.1, rate=4), digits=10)
    final double[] shapeTenthRate4 = {
        0.000000000e+00, 3.032938388e+04, 3.049322494e-02,
        2.211502311e-03, 2.170613371e-05, 5.846590589e-11
    };
    checkDensity(0.1, 4, points, shapeTenthRate4);

    // R2.5: print(dgamma(x, shape=.1, rate=1), digits=10)
    final double[] shapeTenthRate1 = {
        0.000000000e+00, 2.640334143e+04, 1.189704437e-01,
        3.866916944e-02, 7.623306235e-03, 1.663849010e-04
    };
    checkDensity(0.1, 1, points, shapeTenthRate1);

    // To force overflow condition
    // R2.5: print(dgamma(x, shape=1000, rate=100), digits=10)
    final double[] shape1000Rate100 = {
        0.000000000e+00, 0.000000000e+00, 0.000000000e+00,
        0.000000000e+00, 0.000000000e+00, 3.304830256e-84
    };
    checkDensity(1000, 100, points, shape1000Rate100);
}
/**
 * Exercises the density with a shape parameter far below 1 (0.05).
 * Support for very small shape parameters was fixed in STATISTICS-39.
 */
@Test
void testShapeBelow1() {
    final double[] points = {1e-100, 1e-10, 1e-5, 0.1};
    // R2.5: print(dgamma(x, shape=0.05, rate=1), digits=20)
    final double[] reference = {
        5.1360843263583843333e+93,
        1.6241724724359893799e+08,
        2.8882035841935007738e+03,
        4.1419294512123655538e-01,
    };
    checkDensity(0.05, 1, points, reference);
}
/**
 * Asserts the density of a gamma distribution at each point against reference
 * values, using a relative tolerance of 1e-9.
 *
 * @param alpha Shape parameter.
 * @param rate Rate parameter; the distribution is created with scale {@code 1 / rate}.
 * @param x Points at which the density is evaluated.
 * @param expected Reference densities, one per entry of {@code x}.
 */
private void checkDensity(double alpha, double rate, double[] x, double[] expected) {
    final double scale = 1 / rate;
    final GammaDistribution gamma = GammaDistribution.of(alpha, scale);
    int index = 0;
    for (final double point : x) {
        final double reference = expected[index++];
        final double computed = gamma.density(point);
        final double tolerance = Math.abs(reference) * 1e-9;
        Assertions.assertEquals(reference, computed, tolerance);
    }
}
/**
 * Checks the log-density at a fixed set of points for several (shape, rate)
 * pairs, plus sub-normal points for a very small shape. The reference values
 * are quoted from the R / scipy commands in the comments, where {@code x} is
 * the {@code points} array below.
 */
@Test
void testAdditionalLogDensity() {
    final double[] points = {-0.1, 1e-6, 0.5, 1, 2, 5};
    final double negInf = Double.NEGATIVE_INFINITY;

    // R2.5: print(dgamma(x, shape=1, rate=1, log=TRUE), digits=10)
    final double[] shape1Rate1 = {
        negInf, -1e-06, -5e-01, -1e+00, -2e+00, -5e+00
    };
    checkLogDensity(1, 1, points, shape1Rate1);

    // R2.5: print(dgamma(x, shape=2, rate=1, log=TRUE), digits=10)
    final double[] shape2Rate1 = {
        negInf, -13.815511558, -1.193147181, -1.000000000, -1.306852819, -3.390562088
    };
    checkLogDensity(2, 1, points, shape2Rate1);

    // R2.5: print(dgamma(x, shape=4, rate=1, log=TRUE), digits=10)
    final double[] shape4Rate1 = {
        negInf, -43.238292143, -4.371201011, -2.791759469, -1.712317928, -1.963445732
    };
    checkLogDensity(4, 1, points, shape4Rate1);

    // R2.5: print(dgamma(x, shape=4, rate=10, log=TRUE), digits=10)
    final double[] shape4Rate10 = {
        negInf, -34.0279607711, 0.3391393611, -2.5814190973, -10.5019775556, -37.7531053599
    };
    checkLogDensity(4, 10, points, shape4Rate10);

    // R2.5: print(dgamma(x, shape=.1, rate=10, log=TRUE), digits=10)
    final double[] shapeTenthRate10 = {
        negInf, 10.41149536, -6.39862168, -12.02245414, -22.64628660, -53.47094826
    };
    checkLogDensity(0.1, 10, points, shapeTenthRate10);

    // R2.5: print(dgamma(x, shape=.1, rate=20, log=TRUE), digits=10)
    final double[] shapeTenthRate20 = {
        negInf, 10.48080008, -11.32930696, -21.95313942, -42.57697189, -103.40163355
    };
    checkLogDensity(0.1, 20, points, shapeTenthRate20);

    // R2.5: print(dgamma(x, shape=.1, rate=4, log=TRUE), digits=10)
    final double[] shapeTenthRate4 = {
        negInf, 10.319872287, -3.490250753, -6.114083216, -10.737915678, -23.562577337
    };
    checkLogDensity(0.1, 4, points, shapeTenthRate4);

    // R2.5: print(dgamma(x, shape=.1, rate=1, log=TRUE), digits=10)
    final double[] shapeTenthRate1 = {
        negInf, 10.181245850, -2.128880189, -3.252712652, -4.876545114, -8.701206773
    };
    checkLogDensity(0.1, 1, points, shapeTenthRate1);

    // To force overflow condition to pdf=zero
    // R2.5: print(dgamma(x, shape=1000, rate=100, log=TRUE), digits=10)
    final double[] shape1000Rate100 = {
        negInf, -15101.7453846, -2042.5042706, -1400.0502372, -807.5962038, -192.2217627
    };
    checkLogDensity(1000, 100, points, shape1000Rate100);

    // To force overflow condition to pdf=infinity
    final double[] tinyPoints = {1e-315, 1e-320, 1e-323};
    // scipy.stats gamma(1e-2).logpdf(x1), where x1 is tinyPoints
    final double[] tinyReference = {
        713.46168137365419, 724.85948860402209, 731.70997561537104
    };
    checkLogDensity(0.01, 1, tinyPoints, tinyReference);
}
/**
 * Asserts the log-density of a gamma distribution at each point against
 * reference values.
 *
 * <p>Finite reference values are compared with a relative tolerance of 1e-9.
 * Infinite reference values are compared exactly: a relative tolerance
 * computed from an infinite reference is itself infinite, and a delta
 * comparison with an infinite delta accepts any non-NaN value, which would
 * make the assertion vacuous.
 *
 * @param alpha Shape parameter.
 * @param rate Rate parameter; the distribution is created with scale {@code 1 / rate}.
 * @param x Points at which the log-density is evaluated.
 * @param expected Reference log-densities, one per entry of {@code x}.
 */
private void checkLogDensity(double alpha, double rate, double[] x, double[] expected) {
    final GammaDistribution dist = GammaDistribution.of(alpha, 1 / rate);
    for (int i = 0; i < x.length; i++) {
        final double point = x[i];
        final double reference = expected[i];
        final double actual = dist.logDensity(point);
        if (Double.isInfinite(reference)) {
            // Exact comparison: |reference| * 1e-9 would be an infinite delta.
            Assertions.assertEquals(reference, actual, () -> "log density at x = " + point);
        } else {
            Assertions.assertEquals(reference, actual, Math.abs(reference) * 1e-9,
                () -> "log density at x = " + point);
        }
    }
}
/**
 * Legacy log-gamma used as a reference by the MATH-753 tests.
 * This is a copy of {@code double Gamma.logGamma(double)} prior to MATH-849.
 *
 * @param x Argument.
 * @return the Lanczos-based value of log-gamma at {@code x}, or NaN when
 * {@code x} is NaN or not strictly positive
 */
private static double logGamma(double x) {
    // Rejects NaN as well as x <= 0: any comparison with NaN is false.
    if (!(x > 0.0)) {
        return Double.NaN;
    }
    final double lanczosSum = LanczosApproximation.value(x);
    final double shifted = x + LanczosApproximation.g() + .5;
    final double leading = (x + .5) * Math.log(shifted);
    return leading - shifted + HALF_LOG_2_PI + Math.log(lanczosSum / x);
}
/**
 * Legacy gamma density used as a reference by the MATH-753 tests.
 * This is a copy of {@code double GammaDistribution.density(double)} prior to
 * MATH-753.
 *
 * @param x Point at which the density is evaluated.
 * @param shape Shape parameter.
 * @param scale Scale parameter.
 * @return 0 for negative {@code x}; otherwise the directly evaluated density
 */
private static double density(final double x,
                              final double shape,
                              final double scale) {
    if (x < 0) {
        return 0;
    }
    final double powerTerm = Math.pow(x / scale, shape - 1) / scale;
    final double decayTerm = Math.exp(-x / scale);
    final double gammaOfShape = Math.exp(logGamma(shape));
    return powerTerm * decayTerm / gammaOfShape;
}
/**
 * MATH-753: large values of x or shape parameter cause density(double) to
 * overflow. Reference data is generated with the Maxima script
 * gamma-distribution.mac, which can be found in
 * src/test/resources/org/apache/commons/statistics/distribution.
 *
 * <p>Each non-comment line of the resource is parsed as {@code "x, expected"}.
 * The error of the legacy density and of the current density is measured in
 * ULPs of the expected value. Points where the legacy density is not finite
 * are collected separately from points where it is finite.
 *
 * @param shape Shape of gamma distribution (scale is assumed to be 1)
 * @param meanNoOF Allowed mean ULP error when the computed value does not overflow using the old method
 * @param sdNoOF Allowed SD ULP error when the computed value does not overflow using the old method
 * @param meanOF Allowed mean ULP error when the computed value overflows using the old method
 * @param sdOF Allowed SD ULP error when the computed value overflows using the old method
 * @param resourceName Resource name containing a pair of comma separated values for the
 * random variable x and the expected value of the gamma distribution: x, gamma(x; shape, scale=1)
 */
private void doTestMath753(final double shape,
                           final double meanNoOF, final double sdNoOF,
                           final double meanOF, final double sdOF,
                           final String resourceName) {
    final GammaDistribution distribution = GammaDistribution.of(shape, 1.0);
    final SummaryStatistics statOld = new SummaryStatistics();
    // statNewNoOF = No overflow of old function
    // statNewOF = Overflow of old function
    final SummaryStatistics statNewNoOF = new SummaryStatistics();
    final SummaryStatistics statNewOF = new SummaryStatistics();

    final InputStream resourceAsStream = this.getClass().getResourceAsStream(resourceName);
    Assertions.assertNotNull(resourceAsStream, () -> "Could not find resource " + resourceName);
    // Decode with an explicit charset rather than the platform default.
    try (BufferedReader in = new BufferedReader(
            new InputStreamReader(resourceAsStream, StandardCharsets.UTF_8))) {
        for (String line = in.readLine(); line != null; line = in.readLine()) {
            if (line.startsWith("#")) {
                continue;
            }
            final String[] tokens = line.split(", ");
            Assertions.assertEquals(2, tokens.length, "expected two floating-point values");
            final double x = Double.parseDouble(tokens[0]);
            final double expected = Double.parseDouble(tokens[1]);
            final double ulp = Math.ulp(expected);
            final double actualOld = density(x, shape, 1.0);
            final double actualNew = distribution.density(x);
            final double errNew = Math.abs((actualNew - expected) / ulp);
            if (Double.isFinite(actualOld)) {
                statOld.addValue(Math.abs((actualOld - expected) / ulp));
                statNewNoOF.addValue(errNew);
            } else {
                // The legacy density overflowed here; the current one must not.
                Assertions.assertTrue(Double.isFinite(actualNew),
                    () -> "x = " + x + ", shape = " + shape + ", scale = 1.0");
                statNewOF.addValue(errNew);
            }
        }
    } catch (final IOException e) {
        Assertions.fail(e);
    }

    if (statOld.getN() != 0) {
        assertNotWorseThanOld(shape, statOld, statNewNoOF, meanNoOF, sdNoOF);
    }
    if (statNewOF.getN() != 0) {
        assertOverflowErrorWithinLimits(shape, statNewOF, meanOF, sdOF);
    }
}

/**
 * Where the legacy density did not overflow, checks that the current
 * implementation is at least as accurate as the legacy one (min, max, mean and
 * SD of the ULP error) and that its mean and SD stay within the given limits.
 */
private static void assertNotWorseThanOld(final double shape,
                                          final SummaryStatistics statOld,
                                          final SummaryStatistics statNew,
                                          final double meanLimit,
                                          final double sdLimit) {
    final String msg = "shape = " + shape + ", scale = 1.0\n" +
        "Old implementation\n" +
        "------------------\n" +
        statOld.toString() +
        "New implementation\n" +
        "------------------\n" +
        statNew.toString();
    final double newMean = statNew.getMean();
    final double newSd = statNew.getStandardDeviation();
    Assertions.assertTrue(statNew.getMin() <= statOld.getMin(), msg);
    Assertions.assertTrue(statNew.getMax() <= statOld.getMax(), msg);
    Assertions.assertTrue(newMean <= statOld.getMean(), msg);
    Assertions.assertTrue(newSd <= statOld.getStandardDeviation(), msg);
    Assertions.assertTrue(newMean <= meanLimit, msg);
    Assertions.assertTrue(newSd <= sdLimit, msg);
}

/**
 * Where the legacy density overflowed, checks that the mean and SD of the ULP
 * error of the current implementation stay within the given limits.
 */
private static void assertOverflowErrorWithinLimits(final double shape,
                                                    final SummaryStatistics statNewOF,
                                                    final double meanLimit,
                                                    final double sdLimit) {
    final double newMean = statNewOF.getMean();
    final double newSd = statNewOF.getStandardDeviation();
    final String msg = "shape = " + shape +
        ", scale = 1.0" +
        ", max. mean error (ulps) = " + meanLimit +
        ", actual mean error (ulps) = " + newMean +
        ", max. sd of error (ulps) = " + sdLimit +
        ", actual sd of error (ulps) = " + newSd;
    Assertions.assertTrue(newMean <= meanLimit, msg);
    Assertions.assertTrue(newSd <= sdLimit, msg);
}
/** MATH-753 accuracy check for shape 0.25. */
@Test
void testMath753Shape025() {
    final double shape = 0.25;
    // ULP error limits (mean, SD) where the legacy density stays finite.
    final double meanLimit = 1.0;
    final double sdLimit = 1.0;
    // Limits where the legacy density overflows are both zero.
    doTestMath753(shape, meanLimit, sdLimit, 0.0, 0.0, "gamma-distribution-shape-0.25.csv");
}
/** MATH-753 accuracy check for shape 0.5. */
@Test
void testMath753Shape05() {
    final double shape = 0.5;
    // ULP error limits (mean, SD) where the legacy density stays finite.
    final double meanLimit = 1.0;
    final double sdLimit = 1.0;
    // Limits where the legacy density overflows are both zero.
    doTestMath753(shape, meanLimit, sdLimit, 0.0, 0.0, "gamma-distribution-shape-0.5.csv");
}
/** MATH-753 accuracy check for shape 0.75. */
@Test
void testMath753Shape075() {
    final double shape = 0.75;
    // ULP error limits (mean, SD) where the legacy density stays finite.
    final double meanLimit = 1.0;
    final double sdLimit = 1.0;
    // Limits where the legacy density overflows are both zero.
    doTestMath753(shape, meanLimit, sdLimit, 0.0, 0.0, "gamma-distribution-shape-0.75.csv");
}
/** MATH-753 accuracy check for shape 1. */
@Test
void testMath753Shape1() {
    final double shape = 1.0;
    // ULP error limits (mean, SD) where the legacy density stays finite.
    final double meanLimit = 1.0;
    final double sdLimit = 0.5;
    // Limits where the legacy density overflows are both zero.
    doTestMath753(shape, meanLimit, sdLimit, 0.0, 0.0, "gamma-distribution-shape-1.csv");
}
/** MATH-753 accuracy check for shape 8. */
@Test
void testMath753Shape8() {
    final double shape = 8.0;
    // ULP error limits (mean, SD) where the legacy density stays finite.
    final double meanLimit = 1.0;
    final double sdLimit = 1.0;
    // Limits where the legacy density overflows are both zero.
    doTestMath753(shape, meanLimit, sdLimit, 0.0, 0.0, "gamma-distribution-shape-8.csv");
}
/** MATH-753 accuracy check for shape 10. */
@Test
void testMath753Shape10() {
    final double shape = 10.0;
    // ULP error limits (mean, SD) where the legacy density stays finite.
    final double meanLimit = 1.0;
    final double sdLimit = 1.0;
    // Limits where the legacy density overflows are both zero.
    doTestMath753(shape, meanLimit, sdLimit, 0.0, 0.0, "gamma-distribution-shape-10.csv");
}
/** MATH-753 accuracy check for shape 100. */
@Test
void testMath753Shape100() {
    final double shape = 100.0;
    // ULP error limits (mean, SD) where the legacy density stays finite.
    final double meanLimit = 2.0;
    final double sdLimit = 1.0;
    // Limits where the legacy density overflows are both zero.
    doTestMath753(shape, meanLimit, sdLimit, 0.0, 0.0, "gamma-distribution-shape-100.csv");
}
/** MATH-753 accuracy check for shape 142. */
@Test
void testMath753Shape142() {
    final double shape = 142.0;
    // ULP error limits (mean, SD) where the legacy density stays finite.
    final double meanLimit = 1.5;
    final double sdLimit = 1.0;
    // ULP error limits (mean, SD) where the legacy density overflows.
    final double meanLimitOverflow = 40.0;
    final double sdLimitOverflow = 40.0;
    doTestMath753(shape, meanLimit, sdLimit, meanLimitOverflow, sdLimitOverflow,
        "gamma-distribution-shape-142.csv");
}
/** MATH-753 accuracy check for shape 1000. */
@Test
void testMath753Shape1000() {
    final double shape = 1000.0;
    // ULP error limits (mean, SD) where the legacy density stays finite.
    final double meanLimit = 1.0;
    final double sdLimit = 1.0;
    // ULP error limits (mean, SD) where the legacy density overflows.
    final double meanLimitOverflow = 160.0;
    final double sdLimitOverflow = 200.0;
    doTestMath753(shape, meanLimit, sdLimit, meanLimitOverflow, sdLimitOverflow,
        "gamma-distribution-shape-1000.csv");
}
}