RNG-179: Add fast loaded dice roller discrete sampler
diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/EnumeratedDistributionSamplersPerformance.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/EnumeratedDistributionSamplersPerformance.java
index bc55377..b91969c 100644
--- a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/EnumeratedDistributionSamplersPerformance.java
+++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/EnumeratedDistributionSamplersPerformance.java
@@ -22,7 +22,9 @@
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.AliasMethodDiscreteSampler;
+import org.apache.commons.rng.sampling.distribution.DirichletSampler;
import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
+import org.apache.commons.rng.sampling.distribution.FastLoadedDiceRollerDiscreteSampler;
import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
import org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler;
import org.apache.commons.rng.simple.RandomSource;
@@ -41,7 +43,6 @@
import org.openjdk.jmh.annotations.Warmup;
import java.util.Arrays;
-import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
@@ -115,6 +116,9 @@
"AliasMethodDiscreteSampler",
"GuideTableDiscreteSampler",
"MarsagliaTsangWangDiscreteSampler",
+ "FastLoadedDiceRollerDiscreteSampler",
+ "FastLoadedDiceRollerDiscreteSamplerLong",
+ "FastLoadedDiceRollerDiscreteSampler53",
// Uncomment to test non-default parameters
//"AliasMethodDiscreteSamplerNoPad", // Not optimal for sampling
@@ -187,6 +191,19 @@
factory = () -> GuideTableDiscreteSampler.of(rng, probabilities, 8);
} else if ("MarsagliaTsangWangDiscreteSampler".equals(samplerType)) {
factory = () -> MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities);
+ } else if ("FastLoadedDiceRollerDiscreteSampler".equals(samplerType)) {
+ factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, probabilities);
+ } else if ("FastLoadedDiceRollerDiscreteSamplerLong".equals(samplerType)) {
+ // Avoid exact floating-point arithmetic in construction.
+ // Frequencies must sum to less than 2^63; here the sum is ~2^62.
+ // This conversion may omit very small probabilities.
+ final double sum = Arrays.stream(probabilities).sum();
+ final long[] frequencies = Arrays.stream(probabilities)
+ .mapToLong(x -> Math.round(0x1.0p62 * x / sum))
+ .toArray();
+ factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, frequencies);
+ } else if ("FastLoadedDiceRollerDiscreteSampler53".equals(samplerType)) {
+ factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, probabilities, 53);
} else {
throw new IllegalStateException();
}
@@ -335,12 +352,115 @@
/** {@inheritDoc} */
@Override
protected double[] createProbabilities() {
- final double[] probabilities = new double[randomNonUniformSize];
- final ThreadLocalRandom rng = ThreadLocalRandom.current();
- for (int i = 0; i < probabilities.length; i++) {
- probabilities[i] = rng.nextDouble();
- }
- return probabilities;
+ return RandomSource.XO_RO_SHI_RO_128_PP.create()
+ .doubles(randomNonUniformSize).toArray();
+ }
+ }
+
+ /**
+ * Sample random probability arrays from a Dirichlet distribution.
+ *
+ * <p>The distribution ensures the probabilities sum to 1.
+ * The <a href="https://en.wikipedia.org/wiki/Entropy_(information_theory)">entropy</a>
+ * of the probabilities increases with parameters k and alpha.
+ * The following shows the mean and sd of the entropy from 100 samples
+ * for a range of parameters.
+ * <pre>
+ * k alpha mean sd
+ * 4 0.500 1.299 0.374
+ * 4 1.000 1.531 0.294
+ * 4 2.000 1.754 0.172
+ * 8 0.500 2.087 0.348
+ * 8 1.000 2.490 0.266
+ * 8 2.000 2.707 0.142
+ * 16 0.500 3.023 0.287
+ * 16 1.000 3.454 0.166
+ * 16 2.000 3.693 0.095
+ * 32 0.500 4.008 0.182
+ * 32 1.000 4.406 0.125
+ * 32 2.000 4.692 0.075
+ * 64 0.500 4.986 0.151
+ * 64 1.000 5.392 0.115
+ * 64 2.000 5.680 0.048
+ * </pre>
+ */
+ @State(Scope.Benchmark)
+ public static class DirichletDistributionSources extends SamplerSources {
+ /** Number of categories. */
+ @Param({"4", "8", "16"})
+ private int k;
+
+ /** Concentration parameter. */
+ @Param({"0.5", "1", "2"})
+ private double alpha;
+
+ /** {@inheritDoc} */
+ @Override
+ protected double[] createProbabilities() {
+ return DirichletSampler.symmetric(RandomSource.XO_RO_SHI_RO_128_PP.create(),
+ k, alpha).sample();
+ }
+ }
+
+ /**
+ * The {@link FastLoadedDiceRollerDiscreteSampler} samplers to use for testing.
+ * Creates the sampler for each random source and the probabilities using
+ * a Dirichlet distribution.
+ *
+ * <p>This class is a specialized source to allow examination of the effect of the
+ * {@link FastLoadedDiceRollerDiscreteSampler} {@code alpha} parameter.
+ */
+ @State(Scope.Benchmark)
+ public static class FastLoadedDiceRollerDiscreteSamplerSources extends LocalRandomSources {
+ /** Number of categories. */
+ @Param({"4", "8", "16"})
+ private int k;
+
+ /** Concentration parameter. */
+ @Param({"0.5", "1", "2"})
+ private double concentration;
+
+ /** The constructor {@code alpha} parameter. */
+ @Param({"0", "30", "53"})
+ private int alpha;
+
+ /** The factory. */
+ private Supplier<DiscreteSampler> factory;
+
+ /** The sampler. */
+ private DiscreteSampler sampler;
+
+ /**
+ * Gets the sampler.
+ *
+ * @return the sampler.
+ */
+ public DiscreteSampler getSampler() {
+ return sampler;
+ }
+
+ /** Create the distribution probabilities (per iteration as it may vary), the sampler
+ * factory and instantiates sampler. */
+ @Override
+ @Setup(Level.Iteration)
+ public void setup() {
+ super.setup();
+
+ final double[] probabilities =
+ DirichletSampler.symmetric(RandomSource.XO_RO_SHI_RO_128_PP.create(),
+ k, concentration).sample();
+ final UniformRandomProvider rng = getGenerator();
+ factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, probabilities, alpha);
+ sampler = factory.get();
+ }
+
+ /**
+ * Creates a new instance of the sampler.
+ *
+ * @return The sampler.
+ */
+ public DiscreteSampler createSampler() {
+ return factory.get();
}
}
@@ -480,7 +600,7 @@
}
/**
- * Run the sampler.
+ * Create and run the sampler.
*
* @param sources Source of randomness.
* @return the sample value
@@ -502,7 +622,7 @@
}
/**
- * Run the sampler.
+ * Create and run the sampler.
*
* @param sources Source of randomness.
* @return the sample value
@@ -511,4 +631,48 @@
public int singleSampleRandom(RandomDistributionSources sources) {
return sources.createSampler().sample();
}
+
+ /**
+ * Run the sampler.
+ *
+ * @param sources Source of randomness.
+ * @return the sample value
+ */
+ @Benchmark
+ public int sampleDirichlet(DirichletDistributionSources sources) {
+ return sources.getSampler().sample();
+ }
+
+ /**
+ * Create and run the sampler.
+ *
+ * @param sources Source of randomness.
+ * @return the sample value
+ */
+ @Benchmark
+ public int singleSampleDirichlet(DirichletDistributionSources sources) {
+ return sources.createSampler().sample();
+ }
+
+ /**
+ * Run the sampler.
+ *
+ * @param sources Source of randomness.
+ * @return the sample value
+ */
+ @Benchmark
+ public int sampleFast(FastLoadedDiceRollerDiscreteSamplerSources sources) {
+ return sources.getSampler().sample();
+ }
+
+ /**
+ * Create and run the sampler.
+ *
+ * @param sources Source of randomness.
+ * @return the sample value
+ */
+ @Benchmark
+ public int singleSampleFast(FastLoadedDiceRollerDiscreteSamplerSources sources) {
+ return sources.createSampler().sample();
+ }
}
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/FastLoadedDiceRollerDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/FastLoadedDiceRollerDiscreteSampler.java
new file mode 100644
index 0000000..625a73c
--- /dev/null
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/FastLoadedDiceRollerDiscreteSampler.java
@@ -0,0 +1,856 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import java.math.BigInteger;
+import java.util.Arrays;
+import org.apache.commons.rng.UniformRandomProvider;
+
+/**
+ * Distribution sampler that uses the Fast Loaded Dice Roller (FLDR). It can be used to
+ * sample from {@code n} values each with an associated relative weight. If all unique items
+ * are assigned the same weight it is more efficient to use the {@link DiscreteUniformSampler}.
+ *
+ * <p>Given a list {@code L} of {@code n} positive numbers,
+ * where {@code L[i]} represents the relative weight of the {@code i}th side, FLDR returns
+ * integer {@code i} with relative probability {@code L[i]}.
+ *
+ * <p>FLDR produces <em>exact</em> samples from the specified probability distribution.
+ * <ul>
+ * <li>For integer weights, the probability of returning {@code i} is precisely equal to the
+ * rational number {@code L[i] / m}, where {@code m} is the sum of {@code L}.
+ * <li>For floating-points weights, each weight {@code L[i]} is converted to the
+ * corresponding rational number {@code p[i] / q[i]} where {@code p[i]} is a positive integer and
+ * {@code q[i]} is a power of 2. The rational weights are then normalized (exactly) to sum to unity.
+ * </ul>
+ *
+ * <p>Note that if <em>exact</em> samples are not required then an alternative sampler that
+ * ignores very small relative weights may have improved sampling performance.
+ *
+ * <p>This implementation is based on the algorithm in:
+ *
+ * <blockquote>
+ * Feras A. Saad, Cameron E. Freer, Martin C. Rinard, and Vikash K. Mansinghka.
+ * The Fast Loaded Dice Roller: A Near-Optimal Exact Sampler for Discrete Probability
+ * Distributions. In AISTATS 2020: Proceedings of the 23rd International Conference on
+ * Artificial Intelligence and Statistics, Proceedings of Machine Learning Research 108,
+ * Palermo, Sicily, Italy, 2020.
+ * </blockquote>
+ *
+ * <p>Sampling uses {@link UniformRandomProvider#nextInt()} as the source of random bits.
+ *
+ * @see <a href="https://arxiv.org/abs/2003.03830">Saad et al (2020)
+ * Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics,
+ * PMLR 108:1036-1046.</a>
+ * @since 1.5
+ */
+public abstract class FastLoadedDiceRollerDiscreteSampler
+ implements SharedStateDiscreteSampler {
+ /**
+ * The maximum size of an array.
+ *
+ * <p>This value is taken from the limit in Open JDK 8 {@code java.util.ArrayList}.
+ * It allows VMs to reserve some header words in an array.
+ */
+ private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;
+ /** The maximum biased exponent for a finite double.
+ * This is offset by 1023 from {@code Math.getExponent(Double.MAX_VALUE)}. */
+ private static final int MAX_BIASED_EXPONENT = 2046;
+ /** Size of the mantissa of a double. Equal to 52 bits. */
+ private static final int MANTISSA_SIZE = 52;
+ /** Mask to extract the 52-bit mantissa from a long representation of a double. */
+ private static final long MANTISSA_MASK = 0x000f_ffff_ffff_ffffL;
+ /** BigInteger representation of {@link Long#MAX_VALUE}. */
+ private static final BigInteger MAX_LONG = BigInteger.valueOf(Long.MAX_VALUE);
+ /** The maximum offset that will avoid loss of bits for a left shift of a 53-bit value.
+ * The value will remain positive for any shift {@code <=} this value. */
+ private static final int MAX_OFFSET = 10;
+ /** Initial value for no leaf node label. */
+ private static final int NO_LABEL = Integer.MAX_VALUE;
+ /** Name of the sampler. */
+ private static final String SAMPLER_NAME = "Fast Loaded Dice Roller";
+
+ /**
+ * Class to handle the edge case of observations in only one category.
+ */
+ private static class FixedValueDiscreteSampler extends FastLoadedDiceRollerDiscreteSampler {
+ /** The sample value. */
+ private final int sampleValue;
+
+ /**
+ * @param sampleValue Sample value.
+ */
+ FixedValueDiscreteSampler(int sampleValue) {
+ this.sampleValue = sampleValue;
+ }
+
+ @Override
+ public int sample() {
+ return sampleValue;
+ }
+
+ @Override
+ public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
+ return this;
+ }
+
+ @Override
+ public String toString() {
+ return SAMPLER_NAME;
+ }
+ }
+
+ /**
+ * Class to implement the FLDR sample algorithm.
+ */
+ private static class FLDRSampler extends FastLoadedDiceRollerDiscreteSampler {
+ /** Empty boolean source. This is the location of the sign-bit after 31 right shifts on
+ * the boolean source. */
+ private static final int EMPTY_BOOL_SOURCE = 1;
+
+ /** Underlying source of randomness. */
+ private final UniformRandomProvider rng;
+ /** Number of categories. */
+ private final int n;
+ /** Number of levels in the discrete distribution generating (DDG) tree.
+ * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations. */
+ private final int k;
+ /** Number of leaf nodes at each level. */
+ private final int[] h;
+ /** Stores the leaf node labels in increasing order. Named {@code H} in the FLDR paper. */
+ private final int[] lH;
+
+ /**
+ * Provides a bit source for booleans.
+ *
+ * <p>A cached value from a call to {@link UniformRandomProvider#nextInt()}.
+ *
+ * <p>Only stores 31-bits when full as 1 bit has already been consumed.
+ * The sign bit is a flag that shifts down so the source eventually equals 1
+ * when all bits are consumed and will trigger a refill.
+ */
+ private int booleanSource = EMPTY_BOOL_SOURCE;
+
+ /**
+ * Creates a sampler.
+ *
+ * <p>The input parameters are not validated and must be correctly computed tables.
+ *
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param n Number of categories
+ * @param k Number of levels in the discrete distribution generating (DDG) tree.
+ * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations.
+ * @param h Number of leaf nodes at each level.
+ * @param lH Stores the leaf node labels in increasing order.
+ */
+ FLDRSampler(UniformRandomProvider rng,
+ int n,
+ int k,
+ int[] h,
+ int[] lH) {
+ this.rng = rng;
+ this.n = n;
+ this.k = k;
+ // Deliberate direct storage of input arrays
+ this.h = h;
+ this.lH = lH;
+ }
+
+ /**
+ * Creates a copy with a new source of randomness.
+ *
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param source Source to copy.
+ */
+ private FLDRSampler(UniformRandomProvider rng,
+ FLDRSampler source) {
+ this.rng = rng;
+ this.n = source.n;
+ this.k = source.k;
+ this.h = source.h;
+ this.lH = source.lH;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int sample() {
+ // ALGORITHM 5: SAMPLE
+ int c = 0;
+ int d = 0;
+ for (;;) {
+ // b = flip()
+ // d = 2 * d + (1 - b)
+ d = (d << 1) + flip();
+ if (d < h[c]) {
+ // z = H[d][c]
+ final int z = lH[d * k + c];
+ // assert z != NO_LABEL
+ if (z < n) {
+ return z;
+ }
+ d = 0;
+ c = 0;
+ } else {
+ d = d - h[c];
+ c++;
+ }
+ }
+ }
+
+ /**
+ * Provides a source of boolean bits.
+ *
+ * <p>Note: This replicates the boolean cache functionality of
+ * {@code o.a.c.rng.core.source32.IntProvider}. The method has been simplified to return
+ * an {@code int} value rather than a {@code boolean}.
+ *
+ * @return the bit (0 or 1)
+ */
+ private int flip() {
+ int bits = booleanSource;
+ if (bits == 1) {
+ // Refill
+ bits = rng.nextInt();
+ // Store a refill flag in the sign bit and the unused 31 bits, return lowest bit
+ booleanSource = Integer.MIN_VALUE | (bits >>> 1);
+ return bits & 0x1;
+ }
+ // Shift down eventually triggering refill, return current lowest bit
+ booleanSource = bits >>> 1;
+ return bits & 0x1;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public String toString() {
+ return SAMPLER_NAME + " [" + rng.toString() + "]";
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
+ return new FLDRSampler(rng, this);
+ }
+ }
+
+ /** Package-private constructor. */
+ FastLoadedDiceRollerDiscreteSampler() {
+ // Intentionally empty
+ }
+
+ /** {@inheritDoc} */
+ // Redeclare the signature to return a FastLoadedDiceRollerSampler not a SharedStateLongSampler
+ @Override
+ public abstract FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng);
+
+ /**
+ * Creates a sampler.
+ *
+ * <p>Note: The discrete distribution generating (DDG) tree requires {@code (n + 1) * k} entries
+ * where {@code n} is the number of categories, {@code k == ceil(log2(m))} and {@code m}
+ * is the sum of the observed frequencies. An exception is raised if this cannot be allocated
+ * as a single array.
+ *
+ * <p>For reference the sum is limited to {@link Long#MAX_VALUE} and the value {@code k} to 63.
+ * The number of categories is limited to approximately {@code ((2^31 - 1) / k) = 34,087,042}
+ * when the sum of frequencies is large enough to create k=63.
+ *
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param frequencies Observed frequencies of the discrete distribution.
+ * @return the sampler
+ * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
+ * frequency is negative, the sum of all frequencies is either zero or
+ * above {@link Long#MAX_VALUE}, or the size of the discrete distribution generating tree
+ * is too large.
+ */
+ public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
+ long[] frequencies) {
+ final long m = sum(frequencies);
+
+ // Obtain indices of non-zero frequencies
+ final int[] indices = indicesOfNonZero(frequencies);
+
+ // Edge case for 1 non-zero weight. This also handles edge case for 1 observation
+ // (as log2(m) == 0 will break the computation of the DDG tree).
+ if (indices.length == 1) {
+ return new FixedValueDiscreteSampler(indexOfNonZero(frequencies));
+ }
+
+ return createSampler(rng, frequencies, indices, m);
+ }
+
+ /**
+ * Creates a sampler.
+ *
+ * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is a power of 2.
+ * The numerators {@code p} are scaled to use a common denominator before summing.
+ *
+ * <p>All weights are used to create the sampler. Weights with a small magnitude relative
+ * to the largest weight can be excluded using the constructor method with the
+ * relative magnitude parameter {@code alpha} (see {@link #of(UniformRandomProvider, double[], int)}).
+ *
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param weights Weights of the discrete distribution.
+ * @return the sampler
+ * @throws IllegalArgumentException if {@code weights} is null or empty, a
+ * weight is negative, infinite or {@code NaN}, the sum of all weights is zero, or the size
+ * of the discrete distribution generating tree is too large.
+ * @see #of(UniformRandomProvider, double[], int)
+ */
+ public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
+ double[] weights) {
+ return of(rng, weights, 0);
+ }
+
+ /**
+ * Creates a sampler.
+ *
+ * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is
+ * a power of 2. The numerators {@code p} are scaled to use a common
+ * denominator before summing.
+ *
+ * <p>Note: The discrete distribution generating (DDG) tree requires
+ * {@code (n + 1) * k} entries where {@code n} is the number of categories,
+ * {@code k == ceil(log2(m))} and {@code m} is the sum of the weight numerators
+ * {@code q}. An exception is raised if this cannot be allocated as a single
+ * array.
+ *
+ * <p>For reference the value {@code k} is equal to or greater than the ratio of
+ * the largest to the smallest weight expressed as a power of 2. For
+ * {@code Double.MAX_VALUE / Double.MIN_VALUE} this is ~2098. The value
+ * {@code k} increases with the sum of the weight numerators. A number of
+ * weights in excess of 1,000,000 with values equal to {@link Double#MAX_VALUE}
+ * would be required to raise an exception when the minimum weight is
+ * {@link Double#MIN_VALUE}.
+ *
+ * <p>Weights with a small magnitude relative to the largest weight can be
+ * excluded using the relative magnitude parameter {@code alpha}. This will set
+ * any weight to zero if the magnitude is approximately 2<sup>alpha</sup>
+ * <em>smaller</em> than the largest weight. This comparison is made using only
+ * the exponent of the input weights. The {@code alpha} parameter is ignored if
+ * not above zero. Note that a small {@code alpha} parameter will exclude more
+ * weights than a large {@code alpha} parameter.
+ *
+ * <p>The alpha parameter can be used to exclude categories that
+ * have a very low probability of occurrence and will improve the construction
+ * performance of the sampler. The effect on sampling performance depends on
+ * the relative weights of the excluded categories; typically a high {@code alpha}
+ * is used to exclude categories that would be visited with a very low probability
+ * and the sampling performance is unchanged.
+ *
+ * <p><b>Implementation Note</b>
+ *
+ * <p>This method creates a sampler with <em>exact</em> samples from the
+ * specified probability distribution. It is recommended to use this method:
+ * <ul>
+ * <li>if the weights are computed, for example from a probability mass function; or
+ * <li>if the weights sum to an infinite value.
+ * </ul>
+ *
+ * <p>If the weights are computed from empirical observations then it is
+ * recommended to use the factory method
+ * {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
+ * requires the total number of observations to be representable as a long
+ * integer.
+ *
+ * <p>Note that if all weights are scaled by a power of 2 to be integers, and
+ * each integer can be represented as a positive 64-bit long value, then the
+ * sampler created using this method will match the output from a sampler
+ * created with the scaled weights converted to long values for the factory
+ * method {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
+ * assumes the sum of the integer values does not overflow.
+ *
+ * <p>It should be noted that the conversion of weights to rational numbers has
+ * a performance overhead during construction (sampling performance is not
+ * affected). This may be avoided by first converting them to integer values
+ * that can be summed without overflow. For example by scaling values by
+ * {@code 2^62 / sum} and converting to long by casting or rounding.
+ *
+ * <p>This approach may increase the efficiency of construction. The resulting
+ * sampler may no longer produce <em>exact</em> samples from the distribution.
+ * In particular any weights with a converted frequency of zero cannot be
+ * sampled.
+ *
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param weights Weights of the discrete distribution.
+ * @param alpha Alpha parameter.
+ * @return the sampler
+ * @throws IllegalArgumentException if {@code weights} is null or empty, a
+ * weight is negative, infinite or {@code NaN}, the sum of all weights is zero,
+ * or the size of the discrete distribution generating tree is too large.
+ * @see #of(UniformRandomProvider, long[])
+ */
+ public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
+ double[] weights,
+ int alpha) {
+ final int n = checkWeightsNonZeroLength(weights);
+
+ // Convert floating-point double to a relative weight
+ // using a shifted integer representation
+ final long[] frequencies = new long[n];
+ final int[] offsets = new int[n];
+ convertToIntegers(weights, frequencies, offsets, alpha);
+
+ // Obtain indices of non-zero weights
+ final int[] indices = indicesOfNonZero(frequencies);
+
+ // Edge case for 1 non-zero weight.
+ if (indices.length == 1) {
+ return new FixedValueDiscreteSampler(indexOfNonZero(frequencies));
+ }
+
+ final BigInteger m = sum(frequencies, offsets, indices);
+
+ // Use long arithmetic if possible. This occurs when the weights are similar in magnitude.
+ if (m.compareTo(MAX_LONG) <= 0) {
+ // Apply the offset
+ for (int i = 0; i < n; i++) {
+ frequencies[i] <<= offsets[i];
+ }
+ return createSampler(rng, frequencies, indices, m.longValue());
+ }
+
+ return createSampler(rng, frequencies, offsets, indices, m);
+ }
+
+ /**
+ * Sum the frequencies.
+ *
+ * @param frequencies Frequencies.
+ * @return the sum
+ * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
+ * frequency is negative, or the sum of all frequencies is either zero or above
+ * {@link Long#MAX_VALUE}
+ */
+ private static long sum(long[] frequencies) {
+ // Validate
+ if (frequencies == null || frequencies.length == 0) {
+ throw new IllegalArgumentException("frequencies must contain at least 1 value");
+ }
+
+ // Sum the values.
+ // Combine all the sign bits in the observations and the intermediate sum in a flag.
+ long m = 0;
+ long signFlag = 0;
+ for (final long o : frequencies) {
+ m += o;
+ signFlag |= o | m;
+ }
+
+ // Check for a sign-bit.
+ if (signFlag < 0) {
+ // One or more observations were negative, or the sum overflowed.
+ for (final long o : frequencies) {
+ if (o < 0) {
+ throw new IllegalArgumentException("frequencies must contain positive values: " + o);
+ }
+ }
+ throw new IllegalArgumentException("Overflow when summing frequencies");
+ }
+ if (m == 0) {
+ throw new IllegalArgumentException("Sum of frequencies is zero");
+ }
+ return m;
+ }
+
+ /**
+ * Convert the floating-point weights to relative weights represented as
+ * integers {@code value * 2^exponent}. The relative weight as an integer is:
+ *
+ * <pre>
+ * BigInteger.valueOf(value).shiftLeft(exponent)
+ * </pre>
+ *
+ * <p>Note that the weights are created using a common power-of-2 scaling
+ * operation so the minimum exponent is zero.
+ *
+ * <p>A positive {@code alpha} parameter is used to set any weight to zero if
+ * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
+ * largest weight. This comparison is made using only the exponent of the input
+ * weights.
+ *
+ * @param weights Weights of the discrete distribution.
+ * @param values Output floating-point mantissas converted to 53-bit integers.
+ * @param exponents Output power of 2 exponent.
+ * @param alpha Alpha parameter.
+ * @throws IllegalArgumentException if a weight is negative, infinite or
+ * {@code NaN}, or the sum of all weights is zero.
+ */
+ private static void convertToIntegers(double[] weights, long[] values, int[] exponents, int alpha) {
+ int maxExponent = Integer.MIN_VALUE;
+ for (int i = 0; i < weights.length; i++) {
+ final double weight = weights[i];
+ // Ignore zero.
+ // When creating the integer value later using bit shifts the result will remain zero.
+ if (weight == 0) {
+ continue;
+ }
+ final long bits = Double.doubleToRawLongBits(weight);
+
+ // For the IEEE 754 format see Double.longBitsToDouble(long).
+
+ // Extract the exponent (with the sign bit)
+ int exp = (int) (bits >>> MANTISSA_SIZE);
+ // Detect negative, infinite or NaN.
+ // Note: Negative values sign bit will cause the exponent to be too high.
+ if (exp > MAX_BIASED_EXPONENT) {
+ throw new IllegalArgumentException("Invalid weight: " + weight);
+ }
+ long mantissa;
+ if (exp == 0) {
+ // Sub-normal number:
+ mantissa = (bits & MANTISSA_MASK) << 1;
+ // Here we convert to a normalised number by counting the leading zeros
+ // to obtain the number of shifts of the most significant bit in
+ // the mantissa that is required to get a 1 at position 53 (i.e. as
+ // if it were a normal number with assumed leading bit).
+ final int shift = Long.numberOfLeadingZeros(mantissa << 11);
+ mantissa <<= shift;
+ exp -= shift;
+ } else {
+ // Normal number. Add the implicit leading 1-bit.
+ mantissa = (bits & MANTISSA_MASK) | (1L << MANTISSA_SIZE);
+ }
+
+ // Here the floating-point number is equal to:
+ // mantissa * 2^(exp-1075)
+
+ values[i] = mantissa;
+ exponents[i] = exp;
+ maxExponent = Math.max(maxExponent, exp);
+ }
+
+ // No exponent indicates that all weights are zero
+ if (maxExponent == Integer.MIN_VALUE) {
+ throw new IllegalArgumentException("Sum of weights is zero");
+ }
+
+ filterWeights(values, exponents, alpha, maxExponent);
+ scaleWeights(values, exponents);
+ }
+
+ /**
+ * Filters small weights using the {@code alpha} parameter.
+ * A positive {@code alpha} parameter is used to set any weight to zero if
+ * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
+ * largest weight. This comparison is made using only the exponent of the input
+ * weights.
+ *
+ * @param values 53-bit values.
+ * @param exponents Power of 2 exponent.
+ * @param alpha Alpha parameter.
+ * @param maxExponent Maximum exponent.
+ */
+ private static void filterWeights(long[] values, int[] exponents, int alpha, int maxExponent) {
+ if (alpha > 0) {
+ // Filter weights. This must be done before the values are shifted so
+ // the exponent represents the approximate magnitude of the value.
+ for (int i = 0; i < exponents.length; i++) {
+ if (maxExponent - exponents[i] > alpha) {
+ values[i] = 0;
+ }
+ }
+ }
+ }
+
+ /**
+ * Scale the weights represented as integers {@code value * 2^exponent} to use a
+ * minimum exponent of zero. The values are scaled to remove any common trailing zeros
+ * in their representation. This ultimately reduces the size of the discrete distribution
+ * generating (DGG) tree.
+ *
+ * @param values 53-bit values.
+ * @param exponents Power of 2 exponent.
+ */
+ private static void scaleWeights(long[] values, int[] exponents) {
+ // Find the minimum exponent and common trailing zeros.
+ int minExponent = Integer.MAX_VALUE;
+ for (int i = 0; i < exponents.length; i++) {
+ if (values[i] != 0) {
+ minExponent = Math.min(minExponent, exponents[i]);
+ }
+ }
+ // Trailing zeros occur when the original weights have a representation with
+ // less than 52 binary digits, e.g. {1.5, 0.5, 0.25}.
+ int trailingZeros = Long.SIZE;
+ for (int i = 0; i < values.length && trailingZeros != 0; i++) {
+ trailingZeros = Math.min(trailingZeros, Long.numberOfTrailingZeros(values[i]));
+ }
+ // Scale by a power of 2 so the minimum exponent is zero.
+ for (int i = 0; i < exponents.length; i++) {
+ exponents[i] -= minExponent;
+ }
+ // Remove common trailing zeros.
+ if (trailingZeros != 0) {
+ for (int i = 0; i < values.length; i++) {
+ values[i] >>>= trailingZeros;
+ }
+ }
+ }
+
+ /**
+ * Sum the integers at the specified indices.
+ * Integers are represented as {@code value * 2^exponent}.
+ *
+ * @param values 53-bit values.
+ * @param exponents Power of 2 exponent.
+ * @param indices Indices to sum.
+ * @return the sum
+ */
+ private static BigInteger sum(long[] values, int[] exponents, int[] indices) {
+ BigInteger m = BigInteger.ZERO;
+ for (final int i : indices) {
+ m = m.add(toBigInteger(values[i], exponents[i]));
+ }
+ return m;
+ }
+
+ /**
+ * Convert the value and left shift offset to a BigInteger.
+ * It is assumed the value is at most 53-bits. This allows optimising the left
+ * shift if it is below 11 bits.
+ *
+ * @param value 53-bit value.
+ * @param offset Left shift offset (must be positive).
+ * @return the BigInteger
+ */
+ private static BigInteger toBigInteger(long value, int offset) {
+ // Ignore zeros. The sum method uses indices of non-zero values.
+ if (offset <= MAX_OFFSET) {
+ // Assume (value << offset) <= Long.MAX_VALUE
+ return BigInteger.valueOf(value << offset);
+ }
+ return BigInteger.valueOf(value).shiftLeft(offset);
+ }
+
+ /**
+ * Creates the sampler.
+ *
+ * <p>It is assumed the frequencies are all positive and the sum does not
+ * overflow.
+ *
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param frequencies Observed frequencies of the discrete distribution.
+ * @param indices Indices of non-zero frequencies.
+ * @param m Sum of the frequencies.
+ * @return the sampler
+ */
+ private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
+ long[] frequencies,
+ int[] indices,
+ long m) {
+ // ALGORITHM 5: PREPROCESS
+ // a == frequencies
+ // m = sum(a)
+ // h = leaf node count
+ // H = leaf node label (lH)
+
+ final int n = frequencies.length;
+
+ // k = ceil(log2(m))
+ final int k = 64 - Long.numberOfLeadingZeros(m - 1);
+ // r = a(n+1) = 2^k - m
+ final long r = (1L << k) - m;
+
+ // Note:
+ // A sparse matrix can often be used for H, as most of its entries are empty.
+ // This implementation uses a 1D array for efficiency at the cost of memory.
+ // This is limited to approximately ((2^31 - 1) / k), e.g. 34087042 when the sum of
+ // observations is large enough to create k=63.
+ // This could be handled using a 2D array. In practice a number of categories this
+ // large is not expected and is currently not supported.
+ final int[] h = new int[k];
+ final int[] lH = new int[checkArraySize((n + 1L) * k)];
+ Arrays.fill(lH, NO_LABEL);
+
+ int d;
+ for (int j = 0; j < k; j++) {
+ final int shift = (k - 1) - j;
+ final long bitMask = 1L << shift;
+
+ d = 0;
+ for (final int i : indices) {
+ // bool w ← (a[i] >> (k − 1) − j)) & 1
+ // h[j] = h[j] + w
+ // if w then:
+ if ((frequencies[i] & bitMask) != 0) {
+ h[j]++;
+ // H[d][j] = i
+ lH[d * k + j] = i;
+ d++;
+ }
+ }
+ // process a(n+1) without extending the input frequencies array by 1
+ if ((r & bitMask) != 0) {
+ h[j]++;
+ lH[d * k + j] = n;
+ }
+ }
+
+ return new FLDRSampler(rng, n, k, h, lH);
+ }
+
+ /**
+ * Creates the sampler. Frequencies are are represented as a 53-bit value with a
+ * left-shift offset.
+ * <pre>
+ * BigInteger.valueOf(value).shiftLeft(offset)
+ * </pre>
+ *
+ * <p>It is assumed the frequencies are all positive.
+ *
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param frequencies Observed frequencies of the discrete distribution.
+ * @param offsets Left shift offsets (must be positive).
+ * @param indices Indices of non-zero frequencies.
+ * @param m Sum of the frequencies.
+ * @return the sampler
+ */
+ private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
+ long[] frequencies,
+ int[] offsets,
+ int[] indices,
+ BigInteger m) {
+ // Repeat the logic from createSampler(...) using extended arithmetic to test the bits
+
+ // ALGORITHM 5: PREPROCESS
+ // a == frequencies
+ // m = sum(a)
+ // h = leaf node count
+ // H = leaf node label (lH)
+
+ final int n = frequencies.length;
+
+ // k = ceil(log2(m))
+ final int k = m.subtract(BigInteger.ONE).bitLength();
+ // r = a(n+1) = 2^k - m
+ final BigInteger r = BigInteger.ONE.shiftLeft(k).subtract(m);
+
+ final int[] h = new int[k];
+ final int[] lH = new int[checkArraySize((n + 1L) * k)];
+ Arrays.fill(lH, NO_LABEL);
+
+ int d;
+ for (int j = 0; j < k; j++) {
+ final int shift = (k - 1) - j;
+
+ d = 0;
+ for (final int i : indices) {
+ // bool w ← (a[i] >> (k − 1) − j)) & 1
+ // h[j] = h[j] + w
+ // if w then:
+ if (testBit(frequencies[i], offsets[i], shift)) {
+ h[j]++;
+ // H[d][j] = i
+ lH[d * k + j] = i;
+ d++;
+ }
+ }
+ // process a(n+1) without extending the input frequencies array by 1
+ if (r.testBit(shift)) {
+ h[j]++;
+ lH[d * k + j] = n;
+ }
+ }
+
+ return new FLDRSampler(rng, n, k, h, lH);
+ }
+
+ /**
+ * Test the logical bit of the shifted integer representation.
+ * The value is assumed to have at most 53-bits of information. The offset
+ * is assumed to be positive. This is functionally equivalent to:
+ * <pre>
+ * BigInteger.valueOf(value).shiftLeft(offset).testBit(n)
+ * </pre>
+ *
+ * @param value 53-bit value.
+ * @param offset Left shift offset.
+ * @param n Index of bit to test.
+ * @return true if the bit is 1
+ */
+ private static boolean testBit(long value, int offset, int n) {
+ if (n < offset) {
+ // All logical trailing bits are zero
+ return false;
+ }
+ // Test if outside the 53-bit value (note that the implicit 1 bit
+ // has been added to the 52-bit mantissas for 'normal' floating-point numbers).
+ final int bit = n - offset;
+ return bit <= MANTISSA_SIZE && (value & (1L << bit)) != 0;
+ }
+
+ /**
+ * Check the weights have a non-zero length.
+ *
+ * @param weights Weights.
+ * @return the length
+ */
+ private static int checkWeightsNonZeroLength(double[] weights) {
+ if (weights == null || weights.length == 0) {
+ throw new IllegalArgumentException("weights must contain at least 1 value");
+ }
+ return weights.length;
+ }
+
+ /**
+ * Create the indices of non-zero values.
+ *
+ * @param values Values.
+ * @return the indices
+ */
+ private static int[] indicesOfNonZero(long[] values) {
+ int n = 0;
+ final int[] indices = new int[values.length];
+ for (int i = 0; i < values.length; i++) {
+ if (values[i] != 0) {
+ indices[n++] = i;
+ }
+ }
+ return Arrays.copyOf(indices, n);
+ }
+
+ /**
+ * Find the index of the first non-zero frequency.
+ *
+ * @param frequencies Frequencies.
+ * @return the index
+ * @throws IllegalStateException if all frequencies are zero.
+ */
+ static int indexOfNonZero(long[] frequencies) {
+ for (int i = 0; i < frequencies.length; i++) {
+ if (frequencies[i] != 0) {
+ return i;
+ }
+ }
+ throw new IllegalStateException("All frequencies are zero");
+ }
+
+ /**
+ * Check the size is valid for a 1D array.
+ *
+ * @param size Size
+ * @return the size as an {@code int}
+ * @throws IllegalArgumentException if the size is too large for a 1D array.
+ */
+ static int checkArraySize(long size) {
+ if (size > MAX_ARRAY_SIZE) {
+ throw new IllegalArgumentException("Unable to allocate array of size: " + size);
+ }
+ return (int) size;
+ }
+}
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
index 81218db..1a33921 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
@@ -181,12 +181,17 @@
// Any discrete distribution
final int[] discretePoints = {0, 1, 2, 3, 4};
final double[] discreteProbabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
+ final long[] discreteFrequencies = {1, 2, 3, 4, 5};
add(LIST, discretePoints, discreteProbabilities,
MarsagliaTsangWangDiscreteSampler.Enumerated.of(RandomSource.XO_SHI_RO_512_PLUS.create(), discreteProbabilities));
add(LIST, discretePoints, discreteProbabilities,
GuideTableDiscreteSampler.of(RandomSource.XO_SHI_RO_512_SS.create(), discreteProbabilities));
add(LIST, discretePoints, discreteProbabilities,
AliasMethodDiscreteSampler.of(RandomSource.KISS.create(), discreteProbabilities));
+ add(LIST, discretePoints, discreteProbabilities,
+ FastLoadedDiceRollerDiscreteSampler.of(RandomSource.L64_X128_MIX.create(), discreteFrequencies));
+ add(LIST, discretePoints, discreteProbabilities,
+ FastLoadedDiceRollerDiscreteSampler.of(RandomSource.L64_X128_SS.create(), discreteProbabilities));
} catch (Exception e) {
// CHECKSTYLE: stop Regexp
System.err.println("Unexpected exception while creating the list of samplers: " + e);
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/FastLoadedDiceRollerDiscreteSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/FastLoadedDiceRollerDiscreteSamplerTest.java
new file mode 100644
index 0000000..27350dc
--- /dev/null
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/FastLoadedDiceRollerDiscreteSamplerTest.java
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import java.util.Arrays;
+import java.util.function.DoubleUnaryOperator;
+import java.util.stream.Stream;
+import org.apache.commons.math3.stat.descriptive.moment.Mean;
+import org.apache.commons.math3.stat.inference.ChiSquareTest;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.RandomAssert;
+import org.apache.commons.rng.simple.RandomSource;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Assumptions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.junit.jupiter.params.provider.ValueSource;
+
+/**
+ * Test for the {@link FastLoadedDiceRollerDiscreteSampler}.
+ */
+class FastLoadedDiceRollerDiscreteSamplerTest {
+ /**
+ * Creates the sampler.
+ *
+ * @param frequencies Observed frequencies.
+ * @return the FLDR sampler
+ */
+ private static SharedStateDiscreteSampler createSampler(long... frequencies) {
+ final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create();
+ return FastLoadedDiceRollerDiscreteSampler.of(rng, frequencies);
+ }
+
+ /**
+ * Creates the sampler.
+ *
+ * @param weights Weights.
+ * @return the FLDR sampler
+ */
+ private static SharedStateDiscreteSampler createSampler(double... weights) {
+ final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create();
+ return FastLoadedDiceRollerDiscreteSampler.of(rng, weights);
+ }
+
+ /**
+ * Return a stream of invalid frequencies for a discrete distribution.
+ *
+ * @return the stream of invalid frequencies
+ */
+ static Stream<long[]> testFactoryConstructorFrequencies() {
+ return Stream.of(
+ // Null or empty
+ (long[]) null,
+ new long[0],
+ // Negative
+ new long[] {-1, 2, 3},
+ new long[] {1, -2, 3},
+ new long[] {1, 2, -3},
+ // Overflow of sum
+ new long[] {Long.MAX_VALUE, Long.MAX_VALUE},
+ // x+x+2 == 0
+ new long[] {Long.MAX_VALUE, Long.MAX_VALUE, 2},
+ // x+x+x == x - 2 (i.e. positive)
+ new long[] {Long.MAX_VALUE, Long.MAX_VALUE, Long.MAX_VALUE},
+ // Zero sum
+ new long[1],
+ new long[4]
+ );
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void testFactoryConstructorFrequencies(long[] frequencies) {
+ Assertions.assertThrows(IllegalArgumentException.class, () -> createSampler(frequencies));
+ }
+
+ /**
+ * Return a stream of invalid weights for a discrete distribution.
+ *
+ * @return the stream of invalid weights
+ */
+ static Stream<double[]> testFactoryConstructorWeights() {
+ return Stream.of(
+ // Null or empty
+ (double[]) null,
+ new double[0],
+ // Negative, infinite or NaN
+ new double[] {-1, 2, 3},
+ new double[] {1, -2, 3},
+ new double[] {1, 2, -3},
+ new double[] {Double.POSITIVE_INFINITY, 2, 3},
+ new double[] {1, Double.POSITIVE_INFINITY, 3},
+ new double[] {1, 2, Double.POSITIVE_INFINITY},
+ new double[] {Double.NaN, 2, 3},
+ new double[] {1, Double.NaN, 3},
+ new double[] {1, 2, Double.NaN},
+ // Zero sum
+ new double[1],
+ new double[4]
+ );
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void testFactoryConstructorWeights(double[] weights) {
+ Assertions.assertThrows(IllegalArgumentException.class, () -> createSampler(weights));
+ }
+
+ @Test
+ void testToString() {
+ for (final long[] observed : new long[][] {{42}, {1, 2, 3}}) {
+ final SharedStateDiscreteSampler sampler = createSampler(observed);
+ Assertions.assertTrue(sampler.toString().toLowerCase().contains("fast loaded dice roller"));
+ }
+ }
+
+ @Test
+ void testSingleCategory() {
+ final int n = 13;
+ final int[] expected = new int[n];
+ Assertions.assertArrayEquals(expected, createSampler(42).samples(n).toArray());
+ Assertions.assertArrayEquals(expected, createSampler(0.55).samples(n).toArray());
+ }
+
+ @Test
+ void testSingleFrequency() {
+ final long[] frequencies = new long[5];
+ final int category = 2;
+ frequencies[category] = 1;
+ final SharedStateDiscreteSampler sampler = createSampler(frequencies);
+ final int n = 7;
+ final int[] expected = new int[n];
+ Arrays.fill(expected, category);
+ Assertions.assertArrayEquals(expected, sampler.samples(n).toArray());
+ }
+
+ @Test
+ void testSingleWeight() {
+ final double[] weights = new double[5];
+ final int category = 3;
+ weights[category] = 1.5;
+ final SharedStateDiscreteSampler sampler = createSampler(weights);
+ final int n = 6;
+ final int[] expected = new int[n];
+ Arrays.fill(expected, category);
+ Assertions.assertArrayEquals(expected, sampler.samples(n).toArray());
+ }
+
+ @Test
+ void testIndexOfNonZero() {
+ Assertions.assertThrows(IllegalStateException.class,
+ () -> FastLoadedDiceRollerDiscreteSampler.indexOfNonZero(new long[3]));
+ final long[] data = new long[3];
+ for (int i = 0; i < data.length; i++) {
+ data[i] = 13;
+ Assertions.assertEquals(i, FastLoadedDiceRollerDiscreteSampler.indexOfNonZero(data));
+ data[i] = 0;
+ }
+ }
+
+ @ParameterizedTest
+ @ValueSource(longs = {0, 1, -1, Integer.MAX_VALUE, 1L << 34})
+ void testCheckArraySize(long size) {
+ // This is the same value as the sampler
+ final int max = Integer.MAX_VALUE - 8;
+ // Note: The method does not test for negatives.
+ // This is not required when validating a positive int multiplied by another positive int.
+ if (size > max) {
+ Assertions.assertThrows(IllegalArgumentException.class,
+ () -> FastLoadedDiceRollerDiscreteSampler.checkArraySize(size));
+ } else {
+ Assertions.assertEquals((int) size, FastLoadedDiceRollerDiscreteSampler.checkArraySize(size));
+ }
+ }
+
+ /**
+ * Return a stream of expected frequencies for a discrete distribution.
+ *
+ * @return the stream of expected frequencies
+ */
+ static Stream<long[]> testSamplesFrequencies() {
+ return Stream.of(
+ // Single category
+ new long[] {0, 0, 42, 0, 0},
+ // Sum to a power of 2
+ new long[] {1, 1, 2, 3, 1},
+ new long[] {0, 1, 1, 0, 2, 3, 1, 0},
+ // Do not sum to a power of 2
+ new long[] {1, 2, 3, 1, 3},
+ new long[] {1, 0, 2, 0, 3, 1, 3},
+ // Large frequencies
+ new long[] {5126734627834L, 213267384684832L, 126781236718L, 71289979621378L}
+ );
+ }
+
+ /**
+ * Check the distribution of samples match the expected probabilities.
+ *
+ * @param expectedFrequencies Expected frequencies.
+ */
+ @ParameterizedTest
+ @MethodSource
+ void testSamplesFrequencies(long[] expectedFrequencies) {
+ final SharedStateDiscreteSampler sampler = createSampler(expectedFrequencies);
+ final int numberOfSamples = 10000;
+ final long[] samples = new long[expectedFrequencies.length];
+ sampler.samples(numberOfSamples).forEach(x -> samples[x]++);
+
+ // Handle a test with some zero-probability observations by mapping them out
+ int mapSize = 0;
+ double sum = 0;
+ for (final double f : expectedFrequencies) {
+ if (f != 0) {
+ mapSize++;
+ sum += f;
+ }
+ }
+
+ // Single category will break the Chi-square test
+ if (mapSize == 1) {
+ int index = 0;
+ while (index < expectedFrequencies.length) {
+ if (expectedFrequencies[index] != 0) {
+ break;
+ }
+ index++;
+ }
+ Assertions.assertEquals(numberOfSamples, samples[index], "Invalid single category samples");
+ return;
+ }
+
+ final double[] expected = new double[mapSize];
+ final long[] observed = new long[mapSize];
+ for (int i = 0; i < expectedFrequencies.length; i++) {
+ if (expectedFrequencies[i] != 0) {
+ --mapSize;
+ expected[mapSize] = expectedFrequencies[i] / sum;
+ observed[mapSize] = samples[i];
+ } else {
+ Assertions.assertEquals(0, samples[i], "No samples expected from zero probability");
+ }
+ }
+
+ final ChiSquareTest chiSquareTest = new ChiSquareTest();
+ // Pass if we cannot reject null hypothesis that the distributions are the same.
+ Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
+ }
+
+ /**
+ * Return a stream of expected weights for a discrete distribution.
+ *
+ * @return the stream of expected weights
+ */
+ static Stream<double[]> testSamplesWeights() {
+ return Stream.of(
+ // Single category
+ new double[] {0, 0, 0.523, 0, 0},
+ // Sum to a power of 2
+ new double[] {0.125, 0.125, 0.25, 0.375, 0.125},
+ new double[] {0, 0.125, 0.125, 0.25, 0, 0.375, 0.125, 0},
+ // Do not sum to a power of 2
+ new double[] {0.1, 0.2, 0.3, 0.1, 0.3},
+ new double[] {0.1, 0, 0.2, 0, 0.3, 0.1, 0.3},
+ // Sub-normal numbers
+ new double[] {5 * Double.MIN_NORMAL, 2 * Double.MIN_NORMAL, 3 * Double.MIN_NORMAL, 9 * Double.MIN_NORMAL},
+ new double[] {2 * Double.MIN_NORMAL, Double.MIN_NORMAL, 0.5 * Double.MIN_NORMAL, 0.75 * Double.MIN_NORMAL},
+ new double[] {Double.MIN_VALUE, 2 * Double.MIN_VALUE, 3 * Double.MIN_VALUE, 7 * Double.MIN_VALUE},
+ // Large range of magnitude
+ new double[] {1.0, 2.0, Math.scalb(3.0, -32), Math.scalb(4.0, -65), 5.0},
+ new double[] {Math.scalb(1.0, 35), Math.scalb(2.0, 35), Math.scalb(3.0, -32), Math.scalb(4.0, -65), Math.scalb(5.0, 35)},
+ // Sum to infinite
+ new double[] {Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE / 2, Double.MAX_VALUE / 4}
+ );
+ }
+
+ /**
+ * Check the distribution of samples match the expected weights.
+ *
+ * @param weights Category weights.
+ */
+ @ParameterizedTest
+ @MethodSource
+ void testSamplesWeights(double[] weights) {
+ final SharedStateDiscreteSampler sampler = createSampler(weights);
+ final int numberOfSamples = 10000;
+ final long[] samples = new long[weights.length];
+ sampler.samples(numberOfSamples).forEach(x -> samples[x]++);
+
+ // Handle a test with some zero-probability observations by mapping them out
+ int mapSize = 0;
+ double sum = 0;
+ // Handle infinite sum using a rolling mean for normalisation
+ final Mean mean = new Mean();
+ for (final double w : weights) {
+ if (w != 0) {
+ mapSize++;
+ sum += w;
+ mean.increment(w);
+ }
+ }
+
+ // Single category will break the Chi-square test
+ if (mapSize == 1) {
+ int index = 0;
+ while (index < weights.length) {
+ if (weights[index] != 0) {
+ break;
+ }
+ index++;
+ }
+ Assertions.assertEquals(numberOfSamples, samples[index], "Invalid single category samples");
+ return;
+ }
+
+ final double mu = mean.getResult();
+ final int n = mapSize;
+ final double s = sum;
+ final DoubleUnaryOperator normalise = Double.isInfinite(sum) ?
+ x -> (x / mu) * n :
+ x -> x / s;
+
+ final double[] expected = new double[mapSize];
+ final long[] observed = new long[mapSize];
+ for (int i = 0; i < weights.length; i++) {
+ if (weights[i] != 0) {
+ --mapSize;
+ expected[mapSize] = normalise.applyAsDouble(weights[i]);
+ observed[mapSize] = samples[i];
+ } else {
+ Assertions.assertEquals(0, samples[i], "No samples expected from zero probability");
+ }
+ }
+
+ final ChiSquareTest chiSquareTest = new ChiSquareTest();
+ // Pass if we cannot reject null hypothesis that the distributions are the same.
+ Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
+ }
+
+ /**
+ * Check the distribution of samples when the frequencies can be converted to weights without
+ * loss of precision.
+ *
+ * @param frequencies Expected frequencies.
+ */
+ @ParameterizedTest
+ @MethodSource(value = {"testSamplesFrequencies"})
+ void testSamplesWeightsMatchesFrequencies(long[] frequencies) {
+ final double[] weights = new double[frequencies.length];
+ for (int i = 0; i < frequencies.length; i++) {
+ final double w = frequencies[i];
+ Assumptions.assumeTrue((long) w == frequencies[i]);
+ // Ensure the exponent is set in the event of simple frequencies
+ weights[i] = Math.scalb(w, -35);
+ }
+ final long seed = RandomSource.createLong();
+ final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(seed);
+ final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(seed);
+ final SharedStateDiscreteSampler sampler1 =
+ FastLoadedDiceRollerDiscreteSampler.of(rng1, frequencies);
+ final SharedStateDiscreteSampler sampler2 =
+ FastLoadedDiceRollerDiscreteSampler.of(rng2, weights);
+ RandomAssert.assertProduceSameSequence(sampler1, sampler2);
+ }
+
+ /**
+ * Test scaled weights. The sampler uses the relative magnitude of weights and the
+ * output should be invariant to scaling. The weights are sampled from the 2^53 dyadic
+ * rationals in [0, 1). A scale factor of -1021 is the lower limit if a weight is
+ * 2^-53 to maintain a non-zero weight. The upper limit is 1023 if a weight is 1 to avoid
+ * infinite values. Note that it does not matter if the sum of weights is infinite; only
+ * the individual weights must be finite.
+ *
+ * @param scaleFactor the scale factor
+ */
+ @ParameterizedTest
+ @ValueSource(ints = {1023, 67, 1, -59, -1020, -1021})
+ void testScaledWeights(int scaleFactor) {
+ // Weights in [0, 1)
+ final double[] w1 = RandomSource.KISS.create().doubles(10).toArray();
+ final double scale = Math.scalb(1.0, scaleFactor);
+ final double[] w2 = Arrays.stream(w1).map(x -> x * scale).toArray();
+ final long seed = RandomSource.createLong();
+ final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(seed);
+ final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(seed);
+ final SharedStateDiscreteSampler sampler1 =
+ FastLoadedDiceRollerDiscreteSampler.of(rng1, w1);
+ final SharedStateDiscreteSampler sampler2 =
+ FastLoadedDiceRollerDiscreteSampler.of(rng2, w2);
+ RandomAssert.assertProduceSameSequence(sampler1, sampler2);
+ }
+
+ /**
+ * Test the alpha parameter removes small relative weights.
+ * Weights should be removed if they are {@code 2^alpha} smaller than the largest
+ * weight.
+ *
+ * @param alpha Alpha parameter
+ */
+ @ParameterizedTest
+ @ValueSource(ints = {13, 30, 53})
+ void testAlphaRemovesWeights(int alpha) {
+ // The small weight must be > 2^alpha smaller so scale by (alpha + 1)
+ final double small = Math.scalb(1.0, -(alpha + 1));
+ final double[] w1 = {1, 0.5, 0.5, 0};
+ final double[] w2 = {1, 0.5, 0.5, small};
+ final long seed = RandomSource.createLong();
+ final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(seed);
+ final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(seed);
+ final UniformRandomProvider rng3 = RandomSource.SPLIT_MIX_64.create(seed);
+
+ final int n = 10;
+ final int[] s1 = FastLoadedDiceRollerDiscreteSampler.of(rng1, w1).samples(n).toArray();
+ final int[] s2 = FastLoadedDiceRollerDiscreteSampler.of(rng2, w2, alpha).samples(n).toArray();
+ final int[] s3 = FastLoadedDiceRollerDiscreteSampler.of(rng3, w2, alpha + 1).samples(n).toArray();
+
+ Assertions.assertArrayEquals(s1, s2, "alpha parameter should ignore the small weight");
+ Assertions.assertFalse(Arrays.equals(s1, s3), "alpha+1 parameter should not ignore the small weight");
+ }
+
+ static Stream<long[]> testSharedStateSampler() {
+ return Stream.of(
+ new long[] {42},
+ new long[] {1, 1, 2, 3, 1}
+ );
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void testSharedStateSampler(long[] frequencies) {
+ final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(0L);
+ final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(0L);
+ final SharedStateDiscreteSampler sampler1 =
+ FastLoadedDiceRollerDiscreteSampler.of(rng1, frequencies);
+ final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
+ RandomAssert.assertProduceSameSequence(sampler1, sampler2);
+ }
+}
diff --git a/src/main/resources/pmd/pmd-ruleset.xml b/src/main/resources/pmd/pmd-ruleset.xml
index 288690f..1d4f0ec 100644
--- a/src/main/resources/pmd/pmd-ruleset.xml
+++ b/src/main/resources/pmd/pmd-ruleset.xml
@@ -77,7 +77,7 @@
<property name="violationSuppressXPath"
value="//ClassOrInterfaceDeclaration[@SimpleName='PoissonSamplerCache' or @SimpleName='AliasMethodDiscreteSampler'
or @SimpleName='GuideTableDiscreteSampler' or @SimpleName='SharedStateDiscreteProbabilitySampler'
- or @SimpleName='DirichletSampler']"/>
+ or @SimpleName='DirichletSampler' or @SimpleName='FastLoadedDiceRollerDiscreteSampler']"/>
</properties>
</rule>
<rule ref="category/java/bestpractices.xml/SystemPrintln">
@@ -144,6 +144,14 @@
<property name="violationSuppressXPath" value="//ClassOrInterfaceDeclaration[matches(@SimpleName, '^.*Builder$')]"/>
</properties>
</rule>
+ <rule ref="category/java/codestyle.xml/PrematureDeclaration">
+ <properties>
+ <!-- False positive where minExponent is stored before a possible exit point. -->
+ <property name="violationSuppressXPath"
+ value="./ancestor::ClassOrInterfaceDeclaration[@SimpleName='FastLoadedDiceRollerDiscreteSampler'] and
+ ./ancestor::MethodName[@Image='of']"/>
+ </properties>
+ </rule>
<rule ref="category/java/design.xml/NPathComplexity">
<properties>
@@ -229,6 +237,12 @@
value="../MethodDeclaration[@Name='jump' or @Name='longJump']"/>
</properties>
</rule>
+ <rule ref="category/java/design.xml/GodClass">
+ <properties>
+ <property name="violationSuppressXPath"
+ value="./ancestor-or-self::ClassOrInterfaceDeclaration[@SimpleName='FastLoadedDiceRollerDiscreteSampler']"/>
+ </properties>
+ </rule>
<rule ref="category/java/errorprone.xml/AvoidLiteralsInIfCondition">
<properties>