blob: 1a33921dce533e54e52e24d50a4b51ce8f9803ac [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.rng.sampling.distribution;
import java.util.List;
import java.util.ArrayList;
import java.util.Collections;
import org.apache.commons.math3.util.MathArrays;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.simple.RandomSource;
/**
 * List of discrete samplers, each stored together with the points at which
 * it is checked and the probabilities expected at those points.
 */
public final class DiscreteSamplersList {
    /** List of the sampler test data populated by the static initializer. */
    private static final List<DiscreteSamplerTestData> LIST = new ArrayList<>();

    static {
        try {
            // This test uses reference distributions from commons-math3 to compute the expected
            // PMF. These distributions have a dual functionality to compute the PMF and perform
            // sampling. When no sampling is needed for the created distribution, it is advised
            // to pass null as the random generator via the appropriate constructors to avoid the
            // additional initialisation overhead.
            final org.apache.commons.math3.random.RandomGenerator unusedRng = null;

            // List of distributions to test.

            // Binomial ("inverse method").
            final int trialsBinomial = 20;
            final double probSuccessBinomial = 0.67;
            add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial),
                // range [9,16]
                MathArrays.sequence(8, 9, 1),
                RandomSource.KISS.create());
            add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial),
                // range [9,16]
                MathArrays.sequence(8, 9, 1),
                MarsagliaTsangWangDiscreteSampler.Binomial.of(RandomSource.WELL_19937_A.create(), trialsBinomial, probSuccessBinomial));
            // Inverted
            add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, 1 - probSuccessBinomial),
                // range [4,11] = [20-16, 20-9]
                MathArrays.sequence(8, 4, 1),
                MarsagliaTsangWangDiscreteSampler.Binomial.of(RandomSource.WELL_19937_C.create(), trialsBinomial, 1 - probSuccessBinomial));

            // Geometric ("inverse method").
            final double probSuccessGeometric = 0.21;
            add(LIST, new org.apache.commons.math3.distribution.GeometricDistribution(unusedRng, probSuccessGeometric),
                MathArrays.sequence(10, 0, 1),
                RandomSource.ISAAC.create());
            // Geometric.
            add(LIST, new org.apache.commons.math3.distribution.GeometricDistribution(unusedRng, probSuccessGeometric),
                MathArrays.sequence(10, 0, 1),
                GeometricSampler.of(RandomSource.XOR_SHIFT_1024_S_PHI.create(), probSuccessGeometric));

            // Hypergeometric ("inverse method").
            final int popSizeHyper = 34;
            final int numSuccessesHyper = 11;
            final int sampleSizeHyper = 12;
            add(LIST, new org.apache.commons.math3.distribution.HypergeometricDistribution(unusedRng, popSizeHyper, numSuccessesHyper, sampleSizeHyper),
                MathArrays.sequence(10, 0, 1),
                RandomSource.MT.create());

            // Pascal ("inverse method").
            final int numSuccessesPascal = 6;
            final double probSuccessPascal = 0.2;
            add(LIST, new org.apache.commons.math3.distribution.PascalDistribution(unusedRng, numSuccessesPascal, probSuccessPascal),
                MathArrays.sequence(18, 1, 1),
                RandomSource.TWO_CMRES.create());

            // Uniform ("inverse method").
            final int loUniform = -3;
            final int hiUniform = 4;
            add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
                MathArrays.sequence(8, -3, 1),
                RandomSource.SPLIT_MIX_64.create());
            // Uniform (power of 2 range).
            add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
                MathArrays.sequence(8, -3, 1),
                DiscreteUniformSampler.of(RandomSource.MT_64.create(), loUniform, hiUniform));
            // Uniform (large range).
            final int halfMax = Integer.MAX_VALUE / 2;
            final int hiLargeUniform = halfMax + 10;
            final int loLargeUniform = -hiLargeUniform;
            add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loLargeUniform, hiLargeUniform),
                MathArrays.sequence(20, -halfMax, halfMax / 10),
                DiscreteUniformSampler.of(RandomSource.WELL_1024_A.create(), loLargeUniform, hiLargeUniform));
            // Uniform (non-power of 2 range).
            final int rangeNonPowerOf2Uniform = 11;
            final int hiNonPowerOf2Uniform = loUniform + rangeNonPowerOf2Uniform;
            add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiNonPowerOf2Uniform),
                MathArrays.sequence(rangeNonPowerOf2Uniform, -3, 1),
                DiscreteUniformSampler.of(RandomSource.XO_SHI_RO_256_SS.create(), loUniform, hiNonPowerOf2Uniform));

            // Zipf ("inverse method").
            final int numElementsZipf = 5;
            final double exponentZipf = 2.345;
            add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentZipf),
                MathArrays.sequence(5, 1, 1),
                RandomSource.XOR_SHIFT_1024_S_PHI.create());
            // Zipf.
            add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentZipf),
                MathArrays.sequence(5, 1, 1),
                RejectionInversionZipfSampler.of(RandomSource.WELL_19937_C.create(), numElementsZipf, exponentZipf));
            // Zipf (exponent close to 1).
            final double exponentCloseToOneZipf = 1 - 1e-10;
            add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentCloseToOneZipf),
                MathArrays.sequence(5, 1, 1),
                RejectionInversionZipfSampler.of(RandomSource.WELL_19937_C.create(), numElementsZipf, exponentCloseToOneZipf));
            // Zipf (exponent = 0).
            add(LIST, MathArrays.sequence(5, 1, 1), new double[] {0.2, 0.2, 0.2, 0.2, 0.2},
                RejectionInversionZipfSampler.of(RandomSource.XO_RO_SHI_RO_128_PP.create(), numElementsZipf, 0.0));

            // Poisson ("inverse method").
            final double epsilonPoisson = org.apache.commons.math3.distribution.PoissonDistribution.DEFAULT_EPSILON;
            final int maxIterationsPoisson = org.apache.commons.math3.distribution.PoissonDistribution.DEFAULT_MAX_ITERATIONS;
            final double meanPoisson = 3.21;
            add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
                MathArrays.sequence(10, 0, 1),
                RandomSource.MWC_256.create());
            // Poisson.
            add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
                MathArrays.sequence(10, 0, 1),
                PoissonSampler.of(RandomSource.KISS.create(), meanPoisson));
            // Dedicated small mean poisson samplers
            add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
                MathArrays.sequence(10, 0, 1),
                SmallMeanPoissonSampler.of(RandomSource.XO_SHI_RO_256_PLUS.create(), meanPoisson));
            add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
                MathArrays.sequence(10, 0, 1),
                KempSmallMeanPoissonSampler.of(RandomSource.XO_SHI_RO_128_PLUS.create(), meanPoisson));
            add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
                MathArrays.sequence(10, 0, 1),
                MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomSource.XO_SHI_RO_128_PLUS.create(), meanPoisson));
            // LargeMeanPoissonSampler should work at small mean.
            // Note: This hits a code path where the sample from the normal distribution is rejected.
            add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
                MathArrays.sequence(10, 0, 1),
                LargeMeanPoissonSampler.of(RandomSource.PCG_MCG_XSH_RR_32.create(), meanPoisson));

            // Poisson (40 < mean < 80).
            final double largeMeanPoisson = 67.89;
            add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
                MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
                PoissonSampler.of(RandomSource.SPLIT_MIX_64.create(), largeMeanPoisson));
            // Dedicated large mean poisson sampler
            add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
                MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
                LargeMeanPoissonSampler.of(RandomSource.SPLIT_MIX_64.create(), largeMeanPoisson));
            add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
                MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
                MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomSource.XO_RO_SHI_RO_128_PLUS.create(), largeMeanPoisson));

            // Poisson (mean >> 40).
            final double veryLargeMeanPoisson = 543.21;
            add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
                MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
                PoissonSampler.of(RandomSource.SPLIT_MIX_64.create(), veryLargeMeanPoisson));
            // Dedicated large mean poisson sampler
            add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
                MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
                LargeMeanPoissonSampler.of(RandomSource.SPLIT_MIX_64.create(), veryLargeMeanPoisson));
            add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
                MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
                MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomSource.XO_RO_SHI_RO_64_SS.create(), veryLargeMeanPoisson));

            // Any discrete distribution
            final int[] discretePoints = {0, 1, 2, 3, 4};
            final double[] discreteProbabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
            final long[] discreteFrequencies = {1, 2, 3, 4, 5};
            add(LIST, discretePoints, discreteProbabilities,
                MarsagliaTsangWangDiscreteSampler.Enumerated.of(RandomSource.XO_SHI_RO_512_PLUS.create(), discreteProbabilities));
            add(LIST, discretePoints, discreteProbabilities,
                GuideTableDiscreteSampler.of(RandomSource.XO_SHI_RO_512_SS.create(), discreteProbabilities));
            add(LIST, discretePoints, discreteProbabilities,
                AliasMethodDiscreteSampler.of(RandomSource.KISS.create(), discreteProbabilities));
            add(LIST, discretePoints, discreteProbabilities,
                FastLoadedDiceRollerDiscreteSampler.of(RandomSource.L64_X128_MIX.create(), discreteFrequencies));
            add(LIST, discretePoints, discreteProbabilities,
                FastLoadedDiceRollerDiscreteSampler.of(RandomSource.L64_X128_SS.create(), discreteProbabilities));
        } catch (Exception e) {
            // Report the failure before rethrowing so that the cause is printed
            // at the point where the list could not be built.
            // CHECKSTYLE: stop Regexp
            System.err.println("Unexpected exception while creating the list of samplers: " + e);
            e.printStackTrace(System.err);
            // CHECKSTYLE: resume Regexp
            throw new RuntimeException(e);
        }
    }

    /**
     * Class contains only static methods.
     */
    private DiscreteSamplersList() {}

    /**
     * Adds test data for a sampler that applies the inverse transform method
     * to the inverse cumulative probability function of {@code dist}.
     *
     * @param list List of data (one of the "parameters" tested by the Junit parametric test).
     * @param dist Distribution to which the samples are supposed to conform.
     * @param points Outcomes selection.
     * @param rng Generator of uniformly distributed sequences.
     */
    private static void add(List<DiscreteSamplerTestData> list,
                            final org.apache.commons.math3.distribution.IntegerDistribution dist,
                            int[] points,
                            UniformRandomProvider rng) {
        final DiscreteSampler inverseMethodSampler =
            InverseTransformDiscreteSampler.of(rng,
                new DiscreteInverseCumulativeProbabilityFunction() {
                    @Override
                    public int inverseCumulativeProbability(double p) {
                        return dist.inverseCumulativeProbability(p);
                    }
                    @Override
                    public String toString() {
                        return dist.toString();
                    }
                });
        // The static type of the last argument selects the "sampler" overload.
        add(list, dist, points, inverseMethodSampler);
    }

    /**
     * Adds test data for a sampler whose expected probabilities are computed
     * from a reference distribution.
     *
     * @param list List of data (one of the "parameters" tested by the Junit parametric test).
     * @param dist Distribution to which the samples are supposed to conform.
     * @param points Outcomes selection.
     * @param sampler Sampler.
     */
    private static void add(List<DiscreteSamplerTestData> list,
                            final org.apache.commons.math3.distribution.IntegerDistribution dist,
                            int[] points,
                            final DiscreteSampler sampler) {
        add(list, points, getProbabilities(dist, points), sampler);
    }

    /**
     * Adds test data for a sampler whose expected probabilities are given explicitly.
     *
     * @param list List of data (one of the "parameters" tested by the Junit parametric test).
     * @param points Outcomes selection.
     * @param probabilities Probability distribution to which the samples are supposed to conform.
     * @param sampler Sampler.
     */
    private static void add(List<DiscreteSamplerTestData> list,
                            int[] points,
                            final double[] probabilities,
                            final DiscreteSampler sampler) {
        list.add(new DiscreteSamplerTestData(sampler,
                                             points,
                                             probabilities));
    }

    /**
     * Subclasses that are "parametric" tests can forward the call to
     * the "@Parameters"-annotated method to this method.
     *
     * @return an unmodifiable view of the list of sampler test data.
     */
    public static Iterable<DiscreteSamplerTestData> list() {
        return Collections.unmodifiableList(LIST);
    }

    /**
     * Computes the probability of each point according to the distribution.
     *
     * @param dist Distribution.
     * @param points Points.
     * @return the probabilities of the given points according to the distribution.
     * @throws IllegalStateException if a probability is negative.
     */
    private static double[] getProbabilities(org.apache.commons.math3.distribution.IntegerDistribution dist,
                                             int[] points) {
        final int len = points.length;
        final double[] prob = new double[len];

        // XXX Workaround: the uniform distribution's own "probability" method is bypassed.
        // Its replacement value does not depend on the point, so it is computed once.
        final boolean uniform = dist instanceof org.apache.commons.math3.distribution.UniformIntegerDistribution;
        final double uniformProbability = uniform ?
            getProbability((org.apache.commons.math3.distribution.UniformIntegerDistribution) dist) :
            Double.NaN; // Never read when "uniform" is false.

        for (int i = 0; i < len; i++) {
            final double p = uniform ? uniformProbability : dist.probability(points[i]);
            if (p < 0) {
                throw new IllegalStateException(dist + ": p < 0 (at " + points[i] + ", p=" + p + ")");
            }
            prob[i] = p;
        }
        return prob;
    }

    /**
     * Workaround bugs in Commons Math's "UniformIntegerDistribution" (cf. MATH-1396).
     *
     * @param dist Uniform distribution.
     * @return the reciprocal of the number of integers in the support of {@code dist};
     * the bounds are converted to {@code double} before being subtracted.
     */
    private static double getProbability(org.apache.commons.math3.distribution.UniformIntegerDistribution dist) {
        return 1 / ((double) dist.getSupportUpperBound() - (double) dist.getSupportLowerBound() + 1);
    }
}