// blob: 27350dc9f852192b2a2978103303e1b14942d6af [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.rng.sampling.distribution;
import java.util.Arrays;
import java.util.Locale;
import java.util.function.DoubleUnaryOperator;
import java.util.stream.Stream;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.inference.ChiSquareTest;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.RandomAssert;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
/**
 * Test for the {@link FastLoadedDiceRollerDiscreteSampler}.
 *
 * <p>The tests cover validation of the factory arguments, the degenerate case of a
 * single category, the package-private helper methods, the statistical distribution of
 * the samples, and the equivalence of samplers that are expected to consume the same
 * random bits.
 */
class FastLoadedDiceRollerDiscreteSamplerTest {
    /**
     * Creates the sampler from observed frequencies using a {@code SPLIT_MIX_64}
     * generator created without an explicit seed.
     *
     * @param frequencies Observed frequencies.
     * @return the FLDR sampler
     */
    private static SharedStateDiscreteSampler createSampler(long... frequencies) {
        final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create();
        return FastLoadedDiceRollerDiscreteSampler.of(rng, frequencies);
    }

    /**
     * Creates the sampler from weights using a {@code SPLIT_MIX_64} generator created
     * without an explicit seed.
     *
     * @param weights Weights.
     * @return the FLDR sampler
     */
    private static SharedStateDiscreteSampler createSampler(double... weights) {
        final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create();
        return FastLoadedDiceRollerDiscreteSampler.of(rng, weights);
    }

    /**
     * Return a stream of invalid frequencies for a discrete distribution.
     *
     * @return the stream of invalid frequencies
     */
    static Stream<long[]> testFactoryConstructorFrequencies() {
        return Stream.of(
            // Null or empty
            (long[]) null,
            new long[0],
            // Negative
            new long[] {-1, 2, 3},
            new long[] {1, -2, 3},
            new long[] {1, 2, -3},
            // Overflow of sum
            new long[] {Long.MAX_VALUE, Long.MAX_VALUE},
            // x+x+2 == 0
            new long[] {Long.MAX_VALUE, Long.MAX_VALUE, 2},
            // x+x+x == x - 2 (i.e. positive)
            new long[] {Long.MAX_VALUE, Long.MAX_VALUE, Long.MAX_VALUE},
            // Zero sum
            new long[1],
            new long[4]
        );
    }

    /**
     * Check the factory method rejects invalid frequencies.
     *
     * @param frequencies Invalid frequencies.
     */
    @ParameterizedTest
    @MethodSource
    void testFactoryConstructorFrequencies(long[] frequencies) {
        Assertions.assertThrows(IllegalArgumentException.class, () -> createSampler(frequencies));
    }

    /**
     * Return a stream of invalid weights for a discrete distribution.
     *
     * @return the stream of invalid weights
     */
    static Stream<double[]> testFactoryConstructorWeights() {
        return Stream.of(
            // Null or empty
            (double[]) null,
            new double[0],
            // Negative, infinite or NaN
            new double[] {-1, 2, 3},
            new double[] {1, -2, 3},
            new double[] {1, 2, -3},
            new double[] {Double.POSITIVE_INFINITY, 2, 3},
            new double[] {1, Double.POSITIVE_INFINITY, 3},
            new double[] {1, 2, Double.POSITIVE_INFINITY},
            new double[] {Double.NaN, 2, 3},
            new double[] {1, Double.NaN, 3},
            new double[] {1, 2, Double.NaN},
            // Zero sum
            new double[1],
            new double[4]
        );
    }

    /**
     * Check the factory method rejects invalid weights.
     *
     * @param weights Invalid weights.
     */
    @ParameterizedTest
    @MethodSource
    void testFactoryConstructorWeights(double[] weights) {
        Assertions.assertThrows(IllegalArgumentException.class, () -> createSampler(weights));
    }

    /**
     * Check the string representation names the algorithm, both for a single category
     * and for multiple categories.
     */
    @Test
    void testToString() {
        for (final long[] observed : new long[][] {{42}, {1, 2, 3}}) {
            final SharedStateDiscreteSampler sampler = createSampler(observed);
            // Use a fixed locale so the case-insensitive match does not depend on the
            // default locale of the machine running the test.
            Assertions.assertTrue(
                sampler.toString().toLowerCase(Locale.ROOT).contains("fast loaded dice roller"));
        }
    }

    /**
     * Check a distribution with one category always returns category zero.
     */
    @Test
    void testSingleCategory() {
        final int n = 13;
        // Default initialised to zero, i.e. the only category
        final int[] expected = new int[n];
        Assertions.assertArrayEquals(expected, createSampler(42).samples(n).toArray());
        Assertions.assertArrayEquals(expected, createSampler(0.55).samples(n).toArray());
    }

    /**
     * Check a distribution with one non-zero frequency always returns that category.
     */
    @Test
    void testSingleFrequency() {
        final long[] frequencies = new long[5];
        final int category = 2;
        frequencies[category] = 1;
        final SharedStateDiscreteSampler sampler = createSampler(frequencies);
        final int n = 7;
        final int[] expected = new int[n];
        Arrays.fill(expected, category);
        Assertions.assertArrayEquals(expected, sampler.samples(n).toArray());
    }

    /**
     * Check a distribution with one non-zero weight always returns that category.
     */
    @Test
    void testSingleWeight() {
        final double[] weights = new double[5];
        final int category = 3;
        weights[category] = 1.5;
        final SharedStateDiscreteSampler sampler = createSampler(weights);
        final int n = 6;
        final int[] expected = new int[n];
        Arrays.fill(expected, category);
        Assertions.assertArrayEquals(expected, sampler.samples(n).toArray());
    }

    /**
     * Check the index of the non-zero value is found at each position, and that an
     * all-zero array raises an exception.
     */
    @Test
    void testIndexOfNonZero() {
        Assertions.assertThrows(IllegalStateException.class,
            () -> FastLoadedDiceRollerDiscreteSampler.indexOfNonZero(new long[3]));
        final long[] data = new long[3];
        for (int i = 0; i < data.length; i++) {
            // Set exactly one non-zero entry, then reset it for the next position
            data[i] = 13;
            Assertions.assertEquals(i, FastLoadedDiceRollerDiscreteSampler.indexOfNonZero(data));
            data[i] = 0;
        }
    }

    /**
     * Check the array size validation either returns the size as an {@code int} or
     * throws when the size exceeds the maximum.
     *
     * @param size Array size.
     */
    @ParameterizedTest
    @ValueSource(longs = {0, 1, -1, Integer.MAX_VALUE, 1L << 34})
    void testCheckArraySize(long size) {
        // This is the same value as the sampler
        final int max = Integer.MAX_VALUE - 8;
        // Note: The method does not test for negatives.
        // This is not required when validating a positive int multiplied by another positive int.
        if (size > max) {
            Assertions.assertThrows(IllegalArgumentException.class,
                () -> FastLoadedDiceRollerDiscreteSampler.checkArraySize(size));
        } else {
            Assertions.assertEquals((int) size, FastLoadedDiceRollerDiscreteSampler.checkArraySize(size));
        }
    }

    /**
     * Return a stream of expected frequencies for a discrete distribution.
     *
     * @return the stream of expected frequencies
     */
    static Stream<long[]> testSamplesFrequencies() {
        return Stream.of(
            // Single category
            new long[] {0, 0, 42, 0, 0},
            // Sum to a power of 2
            new long[] {1, 1, 2, 3, 1},
            new long[] {0, 1, 1, 0, 2, 3, 1, 0},
            // Do not sum to a power of 2
            new long[] {1, 2, 3, 1, 3},
            new long[] {1, 0, 2, 0, 3, 1, 3},
            // Large frequencies
            new long[] {5126734627834L, 213267384684832L, 126781236718L, 71289979621378L}
        );
    }

    /**
     * Check the distribution of samples match the expected probabilities.
     *
     * @param expectedFrequencies Expected frequencies.
     */
    @ParameterizedTest
    @MethodSource
    void testSamplesFrequencies(long[] expectedFrequencies) {
        final SharedStateDiscreteSampler sampler = createSampler(expectedFrequencies);
        final int numberOfSamples = 10000;
        final long[] samples = new long[expectedFrequencies.length];
        sampler.samples(numberOfSamples).forEach(x -> samples[x]++);
        // The frequencies are only used as relative proportions so widen to double
        final double[] weights = Arrays.stream(expectedFrequencies).asDoubleStream().toArray();
        assertSamplesMatchWeights(weights, samples, numberOfSamples);
    }

    /**
     * Return a stream of expected weights for a discrete distribution.
     *
     * @return the stream of expected weights
     */
    static Stream<double[]> testSamplesWeights() {
        return Stream.of(
            // Single category
            new double[] {0, 0, 0.523, 0, 0},
            // Sum to a power of 2
            new double[] {0.125, 0.125, 0.25, 0.375, 0.125},
            new double[] {0, 0.125, 0.125, 0.25, 0, 0.375, 0.125, 0},
            // Do not sum to a power of 2
            new double[] {0.1, 0.2, 0.3, 0.1, 0.3},
            new double[] {0.1, 0, 0.2, 0, 0.3, 0.1, 0.3},
            // Small numbers: multiples of the smallest normal value, a mix of normal and
            // sub-normal values, and multiples of the smallest sub-normal value
            new double[] {5 * Double.MIN_NORMAL, 2 * Double.MIN_NORMAL, 3 * Double.MIN_NORMAL, 9 * Double.MIN_NORMAL},
            new double[] {2 * Double.MIN_NORMAL, Double.MIN_NORMAL, 0.5 * Double.MIN_NORMAL, 0.75 * Double.MIN_NORMAL},
            new double[] {Double.MIN_VALUE, 2 * Double.MIN_VALUE, 3 * Double.MIN_VALUE, 7 * Double.MIN_VALUE},
            // Large range of magnitude
            new double[] {1.0, 2.0, Math.scalb(3.0, -32), Math.scalb(4.0, -65), 5.0},
            new double[] {Math.scalb(1.0, 35), Math.scalb(2.0, 35), Math.scalb(3.0, -32), Math.scalb(4.0, -65), Math.scalb(5.0, 35)},
            // Sum to infinite
            new double[] {Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE / 2, Double.MAX_VALUE / 4}
        );
    }

    /**
     * Check the distribution of samples match the expected weights.
     *
     * @param weights Category weights.
     */
    @ParameterizedTest
    @MethodSource
    void testSamplesWeights(double[] weights) {
        final SharedStateDiscreteSampler sampler = createSampler(weights);
        final int numberOfSamples = 10000;
        final long[] samples = new long[weights.length];
        sampler.samples(numberOfSamples).forEach(x -> samples[x]++);
        assertSamplesMatchWeights(weights, samples, numberOfSamples);
    }

    /**
     * Assert the observed sample counts are consistent with the category weights.
     *
     * <p>Categories with a zero weight must have no samples. If only one category has a
     * non-zero weight it must have received every sample. Otherwise the non-zero
     * categories are compared to the observed counts using a Chi-square test; the
     * assertion passes if the test cannot reject the null hypothesis that the
     * distributions are the same at a significance level of 0.001.
     *
     * @param weights Category weights (finite and non-negative).
     * @param samples Observed count for each category.
     * @param numberOfSamples Total number of samples that were drawn.
     */
    private static void assertSamplesMatchWeights(double[] weights, long[] samples, int numberOfSamples) {
        // Handle a test with some zero-probability observations by mapping them out
        int nonZeroCount = 0;
        int firstNonZero = -1;
        double sum = 0;
        // Handle infinite sum using a rolling mean for normalisation
        final Mean mean = new Mean();
        for (int i = 0; i < weights.length; i++) {
            final double w = weights[i];
            if (w != 0) {
                if (nonZeroCount == 0) {
                    firstNonZero = i;
                }
                nonZeroCount++;
                sum += w;
                mean.increment(w);
            }
        }

        // Single category will break the Chi-square test
        if (nonZeroCount == 1) {
            Assertions.assertEquals(numberOfSamples, samples[firstNonZero], "Invalid single category samples");
            return;
        }

        // Normalise to probabilities: p = x / sum. When the sum has overflowed use
        // sum = mean * n, dividing in two steps so the infinite sum is never formed.
        final double mu = mean.getResult();
        final int n = nonZeroCount;
        final double s = sum;
        final DoubleUnaryOperator normalise = Double.isInfinite(sum) ?
            x -> x / mu / n :
            x -> x / s;

        final double[] expected = new double[nonZeroCount];
        final long[] observed = new long[nonZeroCount];
        int mapped = 0;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] != 0) {
                expected[mapped] = normalise.applyAsDouble(weights[i]);
                observed[mapped] = samples[i];
                mapped++;
            } else {
                Assertions.assertEquals(0, samples[i], "No samples expected from zero probability");
            }
        }

        final ChiSquareTest chiSquareTest = new ChiSquareTest();
        // Pass if we cannot reject null hypothesis that the distributions are the same.
        Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
    }

    /**
     * Check the distribution of samples when the frequencies can be converted to weights without
     * loss of precision. The sampler created from the frequencies and the sampler created from
     * the equivalent weights are given generators with the same seed and must produce the
     * same sequence.
     *
     * @param frequencies Expected frequencies.
     */
    @ParameterizedTest
    @MethodSource(value = {"testSamplesFrequencies"})
    void testSamplesWeightsMatchesFrequencies(long[] frequencies) {
        final double[] weights = new double[frequencies.length];
        for (int i = 0; i < frequencies.length; i++) {
            final double w = frequencies[i];
            // Skip the test case if the frequency is not exactly representable as a double
            Assumptions.assumeTrue((long) w == frequencies[i]);
            // Ensure the exponent is set in the event of simple frequencies
            weights[i] = Math.scalb(w, -35);
        }
        final long seed = RandomSource.createLong();
        final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(seed);
        final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(seed);
        final SharedStateDiscreteSampler sampler1 =
            FastLoadedDiceRollerDiscreteSampler.of(rng1, frequencies);
        final SharedStateDiscreteSampler sampler2 =
            FastLoadedDiceRollerDiscreteSampler.of(rng2, weights);
        RandomAssert.assertProduceSameSequence(sampler1, sampler2);
    }

    /**
     * Test scaled weights. The sampler uses the relative magnitude of weights and the
     * output should be invariant to scaling. The weights are sampled from the 2^53 dyadic
     * rationals in [0, 1). A scale factor of -1021 is the lower limit if a weight is
     * 2^-53 to maintain a non-zero weight. The upper limit is 1023 if a weight is 1 to avoid
     * infinite values. Note that it does not matter if the sum of weights is infinite; only
     * the individual weights must be finite.
     *
     * @param scaleFactor the scale factor
     */
    @ParameterizedTest
    @ValueSource(ints = {1023, 67, 1, -59, -1020, -1021})
    void testScaledWeights(int scaleFactor) {
        // Weights in [0, 1)
        final double[] w1 = RandomSource.KISS.create().doubles(10).toArray();
        // Scaling by a power of 2 changes only the exponent of each weight
        final double scale = Math.scalb(1.0, scaleFactor);
        final double[] w2 = Arrays.stream(w1).map(x -> x * scale).toArray();
        final long seed = RandomSource.createLong();
        final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(seed);
        final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(seed);
        final SharedStateDiscreteSampler sampler1 =
            FastLoadedDiceRollerDiscreteSampler.of(rng1, w1);
        final SharedStateDiscreteSampler sampler2 =
            FastLoadedDiceRollerDiscreteSampler.of(rng2, w2);
        RandomAssert.assertProduceSameSequence(sampler1, sampler2);
    }

    /**
     * Test the alpha parameter removes small relative weights.
     * Weights should be removed if they are {@code 2^alpha} smaller than the largest
     * weight.
     *
     * @param alpha Alpha parameter
     */
    @ParameterizedTest
    @ValueSource(ints = {13, 30, 53})
    void testAlphaRemovesWeights(int alpha) {
        // The small weight must be > 2^alpha smaller so scale by (alpha + 1)
        final double small = Math.scalb(1.0, -(alpha + 1));
        // Reference weights with the small weight explicitly absent
        final double[] w1 = {1, 0.5, 0.5, 0};
        final double[] w2 = {1, 0.5, 0.5, small};
        // All generators share a seed so any difference is due to the weights used
        final long seed = RandomSource.createLong();
        final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(seed);
        final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(seed);
        final UniformRandomProvider rng3 = RandomSource.SPLIT_MIX_64.create(seed);
        final int n = 10;
        final int[] s1 = FastLoadedDiceRollerDiscreteSampler.of(rng1, w1).samples(n).toArray();
        final int[] s2 = FastLoadedDiceRollerDiscreteSampler.of(rng2, w2, alpha).samples(n).toArray();
        final int[] s3 = FastLoadedDiceRollerDiscreteSampler.of(rng3, w2, alpha + 1).samples(n).toArray();
        Assertions.assertArrayEquals(s1, s2, "alpha parameter should ignore the small weight");
        Assertions.assertFalse(Arrays.equals(s1, s3), "alpha+1 parameter should not ignore the small weight");
    }

    /**
     * Return a stream of frequencies used to test the shared state sampler: a single
     * category and multiple categories.
     *
     * @return the stream of frequencies
     */
    static Stream<long[]> testSharedStateSampler() {
        return Stream.of(
            new long[] {42},
            new long[] {1, 1, 2, 3, 1}
        );
    }

    /**
     * Check a sampler created using {@code withUniformRandomProvider} produces the same
     * sequence as the original sampler when both generators use the same seed.
     *
     * @param frequencies Observed frequencies.
     */
    @ParameterizedTest
    @MethodSource
    void testSharedStateSampler(long[] frequencies) {
        final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(0L);
        final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(0L);
        final SharedStateDiscreteSampler sampler1 =
            FastLoadedDiceRollerDiscreteSampler.of(rng1, frequencies);
        final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
        RandomAssert.assertProduceSameSequence(sampler1, sampler2);
    }
}