| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.commons.rng.sampling.distribution; |
| |
| import java.math.BigInteger; |
| import java.util.Arrays; |
| import org.apache.commons.rng.UniformRandomProvider; |
| |
| /** |
| * Distribution sampler that uses the Fast Loaded Dice Roller (FLDR). It can be used to |
| * sample from {@code n} values each with an associated relative weight. If all unique items |
| * are assigned the same weight it is more efficient to use the {@link DiscreteUniformSampler}. |
| * |
| * <p>Given a list {@code L} of {@code n} positive numbers, |
| * where {@code L[i]} represents the relative weight of the {@code i}th side, FLDR returns |
| * integer {@code i} with relative probability {@code L[i]}. |
| * |
| * <p>FLDR produces <em>exact</em> samples from the specified probability distribution. |
| * <ul> |
| * <li>For integer weights, the probability of returning {@code i} is precisely equal to the |
| * rational number {@code L[i] / m}, where {@code m} is the sum of {@code L}. |
| * <li>For floating-point weights, each weight {@code L[i]} is converted to the |
| * corresponding rational number {@code p[i] / q[i]} where {@code p[i]} is a positive integer and |
| * {@code q[i]} is a power of 2. The rational weights are then normalized (exactly) to sum to unity. |
| * </ul> |
| * |
| * <p>Note that if <em>exact</em> samples are not required then an alternative sampler that |
| * ignores very small relative weights may have improved sampling performance. |
| * |
| * <p>This implementation is based on the algorithm in: |
| * |
| * <blockquote> |
| * Feras A. Saad, Cameron E. Freer, Martin C. Rinard, and Vikash K. Mansinghka. |
| * The Fast Loaded Dice Roller: A Near-Optimal Exact Sampler for Discrete Probability |
| * Distributions. In AISTATS 2020: Proceedings of the 23rd International Conference on |
| * Artificial Intelligence and Statistics, Proceedings of Machine Learning Research 108, |
| * Palermo, Sicily, Italy, 2020. |
| * </blockquote> |
| * |
| * <p>Sampling uses {@link UniformRandomProvider#nextInt()} as the source of random bits. |
| * |
| * @see <a href="https://arxiv.org/abs/2003.03830">Saad et al (2020) |
| * Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics, |
| * PMLR 108:1036-1046.</a> |
| * @since 1.5 |
| */ |
| public abstract class FastLoadedDiceRollerDiscreteSampler |
| implements SharedStateDiscreteSampler { |
| /** |
| * The maximum size of an array. |
| * |
| * <p>This value is taken from the limit in Open JDK 8 {@code java.util.ArrayList}. |
| * It allows VMs to reserve some header words in an array. |
| * |
| * <p>Used as the upper limit when validating the size of a requested 1D array. |
| */ |
| private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8; |
| /** The maximum biased exponent for a finite double. |
| * This is offset by 1023 from {@code Math.getExponent(Double.MAX_VALUE)}. |
| * A raw exponent field extracted together with the sign bit that is above this |
| * value identifies a negative, infinite or {@code NaN} weight. */ |
| private static final int MAX_BIASED_EXPONENT = 2046; |
| /** Size of the mantissa of a double. Equal to 52 bits. |
| * Also the index of the implicit leading 1-bit of a normal number. */ |
| private static final int MANTISSA_SIZE = 52; |
| /** Mask to extract the 52-bit mantissa from a long representation of a double. */ |
| private static final long MANTISSA_MASK = 0x000f_ffff_ffff_ffffL; |
| /** BigInteger representation of {@link Long#MAX_VALUE}. |
| * A sum of weights at or below this value can be processed using {@code long} arithmetic. */ |
| private static final BigInteger MAX_LONG = BigInteger.valueOf(Long.MAX_VALUE); |
| /** The maximum offset that will avoid loss of bits for a left shift of a 53-bit value. |
| * The value will remain positive for any shift {@code <=} this value: a 53-bit value |
| * shifted left by 10 uses at most 63 bits and the sign bit of a {@code long} stays clear. */ |
| private static final int MAX_OFFSET = 10; |
| /** Initial value for no leaf node label. |
| * The leaf node label table is filled with this value before labels are assigned. */ |
| private static final int NO_LABEL = Integer.MAX_VALUE; |
| /** Name of the sampler. Used as the prefix of the {@code toString()} representation. */ |
| private static final String SAMPLER_NAME = "Fast Loaded Dice Roller"; |
| |
| /** |
| * Class to handle the edge case of observations in only one category. |
| */ |
| private static class FixedValueDiscreteSampler extends FastLoadedDiceRollerDiscreteSampler { |
| /** The sample value. */ |
| private final int sampleValue; |
| |
| /** |
| * @param sampleValue Sample value. |
| */ |
| FixedValueDiscreteSampler(int sampleValue) { |
| this.sampleValue = sampleValue; |
| } |
| |
| @Override |
| public int sample() { |
| return sampleValue; |
| } |
| |
| @Override |
| public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { |
| return this; |
| } |
| |
| @Override |
| public String toString() { |
| return SAMPLER_NAME; |
| } |
| } |
| |
| /** |
| * Class to implement the FLDR sample algorithm. |
| */ |
| private static class FLDRSampler extends FastLoadedDiceRollerDiscreteSampler { |
| /** Empty boolean source. This is the location of the sign-bit after 31 right shifts on |
| * the boolean source. */ |
| private static final int EMPTY_BOOL_SOURCE = 1; |
| |
| /** Underlying source of randomness. */ |
| private final UniformRandomProvider rng; |
| /** Number of categories. */ |
| private final int n; |
| /** Number of levels in the discrete distribution generating (DDG) tree. |
| * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations. */ |
| private final int k; |
| /** Number of leaf nodes at each level. */ |
| private final int[] h; |
| /** Stores the leaf node labels in increasing order. Named {@code H} in the FLDR paper. */ |
| private final int[] lH; |
| |
| /** |
| * Provides a bit source for booleans. |
| * |
| * <p>A cached value from a call to {@link UniformRandomProvider#nextInt()}. |
| * |
| * <p>Only stores 31-bits when full as 1 bit has already been consumed. |
| * The sign bit is a flag that shifts down so the source eventually equals 1 |
| * when all bits are consumed and will trigger a refill. |
| */ |
| private int booleanSource = EMPTY_BOOL_SOURCE; |
| |
| /** |
| * Creates a sampler. |
| * |
| * <p>The input parameters are not validated and must be correctly computed tables. |
| * |
| * @param rng Generator of uniformly distributed random numbers. |
| * @param n Number of categories |
| * @param k Number of levels in the discrete distribution generating (DDG) tree. |
| * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations. |
| * @param h Number of leaf nodes at each level. |
| * @param lH Stores the leaf node labels in increasing order. |
| */ |
| FLDRSampler(UniformRandomProvider rng, |
| int n, |
| int k, |
| int[] h, |
| int[] lH) { |
| this.rng = rng; |
| this.n = n; |
| this.k = k; |
| // Deliberate direct storage of input arrays |
| this.h = h; |
| this.lH = lH; |
| } |
| |
| /** |
| * Creates a copy with a new source of randomness. |
| * |
| * @param rng Generator of uniformly distributed random numbers. |
| * @param source Source to copy. |
| */ |
| private FLDRSampler(UniformRandomProvider rng, |
| FLDRSampler source) { |
| this.rng = rng; |
| this.n = source.n; |
| this.k = source.k; |
| this.h = source.h; |
| this.lH = source.lH; |
| } |
| |
| /** {@inheritDoc} */ |
| @Override |
| public int sample() { |
| // ALGORITHM 5: SAMPLE |
| int c = 0; |
| int d = 0; |
| for (;;) { |
| // b = flip() |
| // d = 2 * d + (1 - b) |
| d = (d << 1) + flip(); |
| if (d < h[c]) { |
| // z = H[d][c] |
| final int z = lH[d * k + c]; |
| // assert z != NO_LABEL |
| if (z < n) { |
| return z; |
| } |
| d = 0; |
| c = 0; |
| } else { |
| d = d - h[c]; |
| c++; |
| } |
| } |
| } |
| |
| /** |
| * Provides a source of boolean bits. |
| * |
| * <p>Note: This replicates the boolean cache functionality of |
| * {@code o.a.c.rng.core.source32.IntProvider}. The method has been simplified to return |
| * an {@code int} value rather than a {@code boolean}. |
| * |
| * @return the bit (0 or 1) |
| */ |
| private int flip() { |
| int bits = booleanSource; |
| if (bits == 1) { |
| // Refill |
| bits = rng.nextInt(); |
| // Store a refill flag in the sign bit and the unused 31 bits, return lowest bit |
| booleanSource = Integer.MIN_VALUE | (bits >>> 1); |
| return bits & 0x1; |
| } |
| // Shift down eventually triggering refill, return current lowest bit |
| booleanSource = bits >>> 1; |
| return bits & 0x1; |
| } |
| |
| /** {@inheritDoc} */ |
| @Override |
| public String toString() { |
| return SAMPLER_NAME + " [" + rng.toString() + "]"; |
| } |
| |
| /** {@inheritDoc} */ |
| @Override |
| public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { |
| return new FLDRSampler(rng, this); |
| } |
| } |
| |
| /** Package-private constructor. Samplers are created using the static {@code of} factory methods. */ |
| FastLoadedDiceRollerDiscreteSampler() { |
| // Intentionally empty |
| } |
| |
| /** {@inheritDoc} */ |
| // Redeclare the signature to return a FastLoadedDiceRollerDiscreteSampler and not the |
| // more general type declared by SharedStateDiscreteSampler |
| @Override |
| public abstract FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng); |
| |
| /** |
| * Creates a sampler. |
| * |
| * <p>Note: The discrete distribution generating (DDG) tree requires {@code (n + 1) * k} entries |
| * where {@code n} is the number of categories, {@code k == ceil(log2(m))} and {@code m} |
| * is the sum of the observed frequencies. An exception is raised if this cannot be allocated |
| * as a single array. |
| * |
| * <p>For reference the sum is limited to {@link Long#MAX_VALUE} and the value {@code k} to 63. |
| * The number of categories is limited to approximately {@code ((2^31 - 1) / k) = 34,087,042} |
| * when the sum of frequencies is large enough to create k=63. |
| * |
| * @param rng Generator of uniformly distributed random numbers. |
| * @param frequencies Observed frequencies of the discrete distribution. |
| * @return the sampler |
| * @throws IllegalArgumentException if {@code frequencies} is null or empty, a |
| * frequency is negative, the sum of all frequencies is either zero or |
| * above {@link Long#MAX_VALUE}, or the size of the discrete distribution generating tree |
| * is too large. |
| */ |
| public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng, |
| long[] frequencies) { |
| final long m = sum(frequencies); |
| |
| // Obtain indices of non-zero frequencies |
| final int[] indices = indicesOfNonZero(frequencies); |
| |
| // Edge case for 1 non-zero weight. This also handles edge case for 1 observation |
| // (as log2(m) == 0 will break the computation of the DDG tree). |
| if (indices.length == 1) { |
| return new FixedValueDiscreteSampler(indexOfNonZero(frequencies)); |
| } |
| |
| return createSampler(rng, frequencies, indices, m); |
| } |
| |
| /** |
| * Creates a sampler. |
| * |
| * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is a power of 2. |
| * The numerators {@code p} are scaled to use a common denominator before summing. |
| * |
| * <p>All weights are used to create the sampler. Weights with a small magnitude relative |
| * to the largest weight can be excluded using the constructor method with the |
| * relative magnitude parameter {@code alpha} (see {@link #of(UniformRandomProvider, double[], int)}). |
| * |
| * @param rng Generator of uniformly distributed random numbers. |
| * @param weights Weights of the discrete distribution. |
| * @return the sampler |
| * @throws IllegalArgumentException if {@code weights} is null or empty, a |
| * weight is negative, infinite or {@code NaN}, the sum of all weights is zero, or the size |
| * of the discrete distribution generating tree is too large. |
| * @see #of(UniformRandomProvider, double[], int) |
| */ |
| public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng, |
| double[] weights) { |
| return of(rng, weights, 0); |
| } |
| |
| /** |
| * Creates a sampler. |
| * |
| * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is |
| * a power of 2. The numerators {@code p} are scaled to use a common |
| * denominator before summing. |
| * |
| * <p>Note: The discrete distribution generating (DDG) tree requires |
| * {@code (n + 1) * k} entries where {@code n} is the number of categories, |
| * {@code k == ceil(log2(m))} and {@code m} is the sum of the weight numerators |
| * {@code q}. An exception is raised if this cannot be allocated as a single |
| * array. |
| * |
| * <p>For reference the value {@code k} is equal to or greater than the ratio of |
| * the largest to the smallest weight expressed as a power of 2. For |
| * {@code Double.MAX_VALUE / Double.MIN_VALUE} this is ~2098. The value |
| * {@code k} increases with the sum of the weight numerators. A number of |
| * weights in excess of 1,000,000 with values equal to {@link Double#MAX_VALUE} |
| * would be required to raise an exception when the minimum weight is |
| * {@link Double#MIN_VALUE}. |
| * |
| * <p>Weights with a small magnitude relative to the largest weight can be |
| * excluded using the relative magnitude parameter {@code alpha}. This will set |
| * any weight to zero if the magnitude is approximately 2<sup>alpha</sup> |
| * <em>smaller</em> than the largest weight. This comparison is made using only |
| * the exponent of the input weights. The {@code alpha} parameter is ignored if |
| * not above zero. Note that a small {@code alpha} parameter will exclude more |
| * weights than a large {@code alpha} parameter. |
| * |
| * <p>The alpha parameter can be used to exclude categories that |
| * have a very low probability of occurrence and will improve the construction |
| * performance of the sampler. The effect on sampling performance depends on |
| * the relative weights of the excluded categories; typically a high {@code alpha} |
| * is used to exclude categories that would be visited with a very low probability |
| * and the sampling performance is unchanged. |
| * |
| * <p><b>Implementation Note</b> |
| * |
| * <p>This method creates a sampler with <em>exact</em> samples from the |
| * specified probability distribution. It is recommended to use this method: |
| * <ul> |
| * <li>if the weights are computed, for example from a probability mass function; or |
| * <li>if the weights sum to an infinite value. |
| * </ul> |
| * |
| * <p>If the weights are computed from empirical observations then it is |
| * recommended to use the factory method |
| * {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This |
| * requires the total number of observations to be representable as a long |
| * integer. |
| * |
| * <p>Note that if all weights are scaled by a power of 2 to be integers, and |
| * each integer can be represented as a positive 64-bit long value, then the |
| * sampler created using this method will match the output from a sampler |
| * created with the scaled weights converted to long values for the factory |
| * method {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This |
| * assumes the sum of the integer values does not overflow. |
| * |
| * <p>It should be noted that the conversion of weights to rational numbers has |
| * a performance overhead during construction (sampling performance is not |
| * affected). This may be avoided by first converting them to integer values |
| * that can be summed without overflow. For example by scaling values by |
| * {@code 2^62 / sum} and converting to long by casting or rounding. |
| * |
| * <p>This approach may increase the efficiency of construction. The resulting |
| * sampler may no longer produce <em>exact</em> samples from the distribution. |
| * In particular any weights with a converted frequency of zero cannot be |
| * sampled. |
| * |
| * @param rng Generator of uniformly distributed random numbers. |
| * @param weights Weights of the discrete distribution. |
| * @param alpha Alpha parameter. |
| * @return the sampler |
| * @throws IllegalArgumentException if {@code weights} is null or empty, a |
| * weight is negative, infinite or {@code NaN}, the sum of all weights is zero, |
| * or the size of the discrete distribution generating tree is too large. |
| * @see #of(UniformRandomProvider, long[]) |
| */ |
| public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng, |
| double[] weights, |
| int alpha) { |
| final int n = checkWeightsNonZeroLength(weights); |
| |
| // Convert floating-point double to a relative weight |
| // using a shifted integer representation |
| final long[] frequencies = new long[n]; |
| final int[] offsets = new int[n]; |
| convertToIntegers(weights, frequencies, offsets, alpha); |
| |
| // Obtain indices of non-zero weights |
| final int[] indices = indicesOfNonZero(frequencies); |
| |
| // Edge case for 1 non-zero weight. |
| if (indices.length == 1) { |
| return new FixedValueDiscreteSampler(indexOfNonZero(frequencies)); |
| } |
| |
| final BigInteger m = sum(frequencies, offsets, indices); |
| |
| // Use long arithmetic if possible. This occurs when the weights are similar in magnitude. |
| if (m.compareTo(MAX_LONG) <= 0) { |
| // Apply the offset |
| for (int i = 0; i < n; i++) { |
| frequencies[i] <<= offsets[i]; |
| } |
| return createSampler(rng, frequencies, indices, m.longValue()); |
| } |
| |
| return createSampler(rng, frequencies, offsets, indices, m); |
| } |
| |
| /** |
| * Sum the frequencies. |
| * |
| * @param frequencies Frequencies. |
| * @return the sum |
| * @throws IllegalArgumentException if {@code frequencies} is null or empty, a |
| * frequency is negative, or the sum of all frequencies is either zero or above |
| * {@link Long#MAX_VALUE} |
| */ |
| private static long sum(long[] frequencies) { |
| // Validate |
| if (frequencies == null || frequencies.length == 0) { |
| throw new IllegalArgumentException("frequencies must contain at least 1 value"); |
| } |
| |
| // Sum the values. |
| // Combine all the sign bits in the observations and the intermediate sum in a flag. |
| long m = 0; |
| long signFlag = 0; |
| for (final long o : frequencies) { |
| m += o; |
| signFlag |= o | m; |
| } |
| |
| // Check for a sign-bit. |
| if (signFlag < 0) { |
| // One or more observations were negative, or the sum overflowed. |
| for (final long o : frequencies) { |
| if (o < 0) { |
| throw new IllegalArgumentException("frequencies must contain positive values: " + o); |
| } |
| } |
| throw new IllegalArgumentException("Overflow when summing frequencies"); |
| } |
| if (m == 0) { |
| throw new IllegalArgumentException("Sum of frequencies is zero"); |
| } |
| return m; |
| } |
| |
| /** |
| * Convert the floating-point weights to relative weights represented as |
| * integers {@code value * 2^exponent}. The relative weight as an integer is: |
| * |
| * <pre> |
| * BigInteger.valueOf(value).shiftLeft(exponent) |
| * </pre> |
| * |
| * <p>Note that the weights are created using a common power-of-2 scaling |
| * operation so the minimum exponent is zero. |
| * |
| * <p>A positive {@code alpha} parameter is used to set any weight to zero if |
| * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the |
| * largest weight. This comparison is made using only the exponent of the input |
| * weights. |
| * |
| * @param weights Weights of the discrete distribution. |
| * @param values Output floating-point mantissas converted to 53-bit integers. |
| * @param exponents Output power of 2 exponent. |
| * @param alpha Alpha parameter. |
| * @throws IllegalArgumentException if a weight is negative, infinite or |
| * {@code NaN}, or the sum of all weights is zero. |
| */ |
| private static void convertToIntegers(double[] weights, long[] values, int[] exponents, int alpha) { |
| int maxExponent = Integer.MIN_VALUE; |
| for (int i = 0; i < weights.length; i++) { |
| final double weight = weights[i]; |
| // Ignore zero. |
| // When creating the integer value later using bit shifts the result will remain zero. |
| if (weight == 0) { |
| continue; |
| } |
| final long bits = Double.doubleToRawLongBits(weight); |
| |
| // For the IEEE 754 format see Double.longBitsToDouble(long). |
| |
| // Extract the exponent (with the sign bit) |
| int exp = (int) (bits >>> MANTISSA_SIZE); |
| // Detect negative, infinite or NaN. |
| // Note: Negative values sign bit will cause the exponent to be too high. |
| if (exp > MAX_BIASED_EXPONENT) { |
| throw new IllegalArgumentException("Invalid weight: " + weight); |
| } |
| long mantissa; |
| if (exp == 0) { |
| // Sub-normal number: |
| mantissa = (bits & MANTISSA_MASK) << 1; |
| // Here we convert to a normalised number by counting the leading zeros |
| // to obtain the number of shifts of the most significant bit in |
| // the mantissa that is required to get a 1 at position 53 (i.e. as |
| // if it were a normal number with assumed leading bit). |
| final int shift = Long.numberOfLeadingZeros(mantissa << 11); |
| mantissa <<= shift; |
| exp -= shift; |
| } else { |
| // Normal number. Add the implicit leading 1-bit. |
| mantissa = (bits & MANTISSA_MASK) | (1L << MANTISSA_SIZE); |
| } |
| |
| // Here the floating-point number is equal to: |
| // mantissa * 2^(exp-1075) |
| |
| values[i] = mantissa; |
| exponents[i] = exp; |
| maxExponent = Math.max(maxExponent, exp); |
| } |
| |
| // No exponent indicates that all weights are zero |
| if (maxExponent == Integer.MIN_VALUE) { |
| throw new IllegalArgumentException("Sum of weights is zero"); |
| } |
| |
| filterWeights(values, exponents, alpha, maxExponent); |
| scaleWeights(values, exponents); |
| } |
| |
| /** |
| * Filters small weights using the {@code alpha} parameter. |
| * A positive {@code alpha} parameter is used to set any weight to zero if |
| * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the |
| * largest weight. This comparison is made using only the exponent of the input |
| * weights. |
| * |
| * @param values 53-bit values. |
| * @param exponents Power of 2 exponent. |
| * @param alpha Alpha parameter. |
| * @param maxExponent Maximum exponent. |
| */ |
| private static void filterWeights(long[] values, int[] exponents, int alpha, int maxExponent) { |
| if (alpha > 0) { |
| // Filter weights. This must be done before the values are shifted so |
| // the exponent represents the approximate magnitude of the value. |
| for (int i = 0; i < exponents.length; i++) { |
| if (maxExponent - exponents[i] > alpha) { |
| values[i] = 0; |
| } |
| } |
| } |
| } |
| |
| /** |
| * Scale the weights represented as integers {@code value * 2^exponent} to use a |
| * minimum exponent of zero. The values are scaled to remove any common trailing zeros |
| * in their representation. This ultimately reduces the size of the discrete distribution |
| * generating (DGG) tree. |
| * |
| * @param values 53-bit values. |
| * @param exponents Power of 2 exponent. |
| */ |
| private static void scaleWeights(long[] values, int[] exponents) { |
| // Find the minimum exponent and common trailing zeros. |
| int minExponent = Integer.MAX_VALUE; |
| for (int i = 0; i < exponents.length; i++) { |
| if (values[i] != 0) { |
| minExponent = Math.min(minExponent, exponents[i]); |
| } |
| } |
| // Trailing zeros occur when the original weights have a representation with |
| // less than 52 binary digits, e.g. {1.5, 0.5, 0.25}. |
| int trailingZeros = Long.SIZE; |
| for (int i = 0; i < values.length && trailingZeros != 0; i++) { |
| trailingZeros = Math.min(trailingZeros, Long.numberOfTrailingZeros(values[i])); |
| } |
| // Scale by a power of 2 so the minimum exponent is zero. |
| for (int i = 0; i < exponents.length; i++) { |
| exponents[i] -= minExponent; |
| } |
| // Remove common trailing zeros. |
| if (trailingZeros != 0) { |
| for (int i = 0; i < values.length; i++) { |
| values[i] >>>= trailingZeros; |
| } |
| } |
| } |
| |
| /** |
| * Sum the integers at the specified indices. |
| * Integers are represented as {@code value * 2^exponent}. |
| * |
| * @param values 53-bit values. |
| * @param exponents Power of 2 exponent. |
| * @param indices Indices to sum. |
| * @return the sum |
| */ |
| private static BigInteger sum(long[] values, int[] exponents, int[] indices) { |
| BigInteger m = BigInteger.ZERO; |
| for (final int i : indices) { |
| m = m.add(toBigInteger(values[i], exponents[i])); |
| } |
| return m; |
| } |
| |
| /** |
| * Convert the value and left shift offset to a BigInteger. |
| * It is assumed the value is at most 53-bits. This allows optimising the left |
| * shift if it is below 11 bits. |
| * |
| * @param value 53-bit value. |
| * @param offset Left shift offset (must be positive). |
| * @return the BigInteger |
| */ |
| private static BigInteger toBigInteger(long value, int offset) { |
| // Ignore zeros. The sum method uses indices of non-zero values. |
| if (offset <= MAX_OFFSET) { |
| // Assume (value << offset) <= Long.MAX_VALUE |
| return BigInteger.valueOf(value << offset); |
| } |
| return BigInteger.valueOf(value).shiftLeft(offset); |
| } |
| |
| /** |
| * Creates the sampler. |
| * |
| * <p>It is assumed the frequencies are all positive and the sum does not |
| * overflow. |
| * |
| * @param rng Generator of uniformly distributed random numbers. |
| * @param frequencies Observed frequencies of the discrete distribution. |
| * @param indices Indices of non-zero frequencies. |
| * @param m Sum of the frequencies. |
| * @return the sampler |
| */ |
| private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng, |
| long[] frequencies, |
| int[] indices, |
| long m) { |
| // ALGORITHM 5: PREPROCESS |
| // a == frequencies |
| // m = sum(a) |
| // h = leaf node count |
| // H = leaf node label (lH) |
| |
| final int n = frequencies.length; |
| |
| // k = ceil(log2(m)) |
| final int k = 64 - Long.numberOfLeadingZeros(m - 1); |
| // r = a(n+1) = 2^k - m |
| final long r = (1L << k) - m; |
| |
| // Note: |
| // A sparse matrix can often be used for H, as most of its entries are empty. |
| // This implementation uses a 1D array for efficiency at the cost of memory. |
| // This is limited to approximately ((2^31 - 1) / k), e.g. 34087042 when the sum of |
| // observations is large enough to create k=63. |
| // This could be handled using a 2D array. In practice a number of categories this |
| // large is not expected and is currently not supported. |
| final int[] h = new int[k]; |
| final int[] lH = new int[checkArraySize((n + 1L) * k)]; |
| Arrays.fill(lH, NO_LABEL); |
| |
| int d; |
| for (int j = 0; j < k; j++) { |
| final int shift = (k - 1) - j; |
| final long bitMask = 1L << shift; |
| |
| d = 0; |
| for (final int i : indices) { |
| // bool w ← (a[i] >> (k − 1) − j)) & 1 |
| // h[j] = h[j] + w |
| // if w then: |
| if ((frequencies[i] & bitMask) != 0) { |
| h[j]++; |
| // H[d][j] = i |
| lH[d * k + j] = i; |
| d++; |
| } |
| } |
| // process a(n+1) without extending the input frequencies array by 1 |
| if ((r & bitMask) != 0) { |
| h[j]++; |
| lH[d * k + j] = n; |
| } |
| } |
| |
| return new FLDRSampler(rng, n, k, h, lH); |
| } |
| |
| /** |
| * Creates the sampler. Frequencies are are represented as a 53-bit value with a |
| * left-shift offset. |
| * <pre> |
| * BigInteger.valueOf(value).shiftLeft(offset) |
| * </pre> |
| * |
| * <p>It is assumed the frequencies are all positive. |
| * |
| * @param rng Generator of uniformly distributed random numbers. |
| * @param frequencies Observed frequencies of the discrete distribution. |
| * @param offsets Left shift offsets (must be positive). |
| * @param indices Indices of non-zero frequencies. |
| * @param m Sum of the frequencies. |
| * @return the sampler |
| */ |
| private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng, |
| long[] frequencies, |
| int[] offsets, |
| int[] indices, |
| BigInteger m) { |
| // Repeat the logic from createSampler(...) using extended arithmetic to test the bits |
| |
| // ALGORITHM 5: PREPROCESS |
| // a == frequencies |
| // m = sum(a) |
| // h = leaf node count |
| // H = leaf node label (lH) |
| |
| final int n = frequencies.length; |
| |
| // k = ceil(log2(m)) |
| final int k = m.subtract(BigInteger.ONE).bitLength(); |
| // r = a(n+1) = 2^k - m |
| final BigInteger r = BigInteger.ONE.shiftLeft(k).subtract(m); |
| |
| final int[] h = new int[k]; |
| final int[] lH = new int[checkArraySize((n + 1L) * k)]; |
| Arrays.fill(lH, NO_LABEL); |
| |
| int d; |
| for (int j = 0; j < k; j++) { |
| final int shift = (k - 1) - j; |
| |
| d = 0; |
| for (final int i : indices) { |
| // bool w ← (a[i] >> (k − 1) − j)) & 1 |
| // h[j] = h[j] + w |
| // if w then: |
| if (testBit(frequencies[i], offsets[i], shift)) { |
| h[j]++; |
| // H[d][j] = i |
| lH[d * k + j] = i; |
| d++; |
| } |
| } |
| // process a(n+1) without extending the input frequencies array by 1 |
| if (r.testBit(shift)) { |
| h[j]++; |
| lH[d * k + j] = n; |
| } |
| } |
| |
| return new FLDRSampler(rng, n, k, h, lH); |
| } |
| |
| /** |
| * Test the logical bit of the shifted integer representation. |
| * The value is assumed to have at most 53-bits of information. The offset |
| * is assumed to be positive. This is functionally equivalent to: |
| * <pre> |
| * BigInteger.valueOf(value).shiftLeft(offset).testBit(n) |
| * </pre> |
| * |
| * @param value 53-bit value. |
| * @param offset Left shift offset. |
| * @param n Index of bit to test. |
| * @return true if the bit is 1 |
| */ |
| private static boolean testBit(long value, int offset, int n) { |
| if (n < offset) { |
| // All logical trailing bits are zero |
| return false; |
| } |
| // Test if outside the 53-bit value (note that the implicit 1 bit |
| // has been added to the 52-bit mantissas for 'normal' floating-point numbers). |
| final int bit = n - offset; |
| return bit <= MANTISSA_SIZE && (value & (1L << bit)) != 0; |
| } |
| |
| /** |
| * Check the weights have a non-zero length. |
| * |
| * @param weights Weights. |
| * @return the length |
| */ |
| private static int checkWeightsNonZeroLength(double[] weights) { |
| if (weights == null || weights.length == 0) { |
| throw new IllegalArgumentException("weights must contain at least 1 value"); |
| } |
| return weights.length; |
| } |
| |
| /** |
| * Create the indices of non-zero values. |
| * |
| * @param values Values. |
| * @return the indices |
| */ |
| private static int[] indicesOfNonZero(long[] values) { |
| int n = 0; |
| final int[] indices = new int[values.length]; |
| for (int i = 0; i < values.length; i++) { |
| if (values[i] != 0) { |
| indices[n++] = i; |
| } |
| } |
| return Arrays.copyOf(indices, n); |
| } |
| |
| /** |
| * Find the index of the first non-zero frequency. |
| * |
| * @param frequencies Frequencies. |
| * @return the index |
| * @throws IllegalStateException if all frequencies are zero. |
| */ |
| static int indexOfNonZero(long[] frequencies) { |
| for (int i = 0; i < frequencies.length; i++) { |
| if (frequencies[i] != 0) { |
| return i; |
| } |
| } |
| throw new IllegalStateException("All frequencies are zero"); |
| } |
| |
| /** |
| * Check the size is valid for a 1D array. |
| * |
| * @param size Size |
| * @return the size as an {@code int} |
| * @throws IllegalArgumentException if the size is too large for a 1D array. |
| */ |
| static int checkArraySize(long size) { |
| if (size > MAX_ARRAY_SIZE) { |
| throw new IllegalArgumentException("Unable to allocate array of size: " + size); |
| } |
| return (int) size; |
| } |
| } |