blob: e5ffc6a6e279551d39601ce953b6db2e64dafec6 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.rng.sampling;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.junit.Assert;
import org.junit.Test;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.simple.RandomSource;
/**
* Test class for {@link DiscreteProbabilityCollectionSampler}.
*/
public class DiscreteProbabilityCollectionSamplerTest {
/** RNG. */
private final UniformRandomProvider rng = RandomSource.create(RandomSource.WELL_1024_A);
@Test(expected = IllegalArgumentException.class)
public void testPrecondition1() {
// Size mismatch
new DiscreteProbabilityCollectionSampler<Double>(rng,
Arrays.asList(new Double[] {1d, 2d}),
new double[] {0d});
}
@Test(expected = IllegalArgumentException.class)
public void testPrecondition2() {
// Negative probability
new DiscreteProbabilityCollectionSampler<Double>(rng,
Arrays.asList(new Double[] {1d, 2d}),
new double[] {0d, -1d});
}
@Test(expected = IllegalArgumentException.class)
public void testPrecondition3() {
// Probabilities do not sum above 0
new DiscreteProbabilityCollectionSampler<Double>(rng,
Arrays.asList(new Double[] {1d, 2d}),
new double[] {0d, 0d});
}
@Test(expected = IllegalArgumentException.class)
public void testPrecondition4() {
// NaN probability
new DiscreteProbabilityCollectionSampler<Double>(rng,
Arrays.asList(new Double[] {1d, 2d}),
new double[] {0d, Double.NaN});
}
@Test(expected = IllegalArgumentException.class)
public void testPrecondition5() {
// Infinite probability
new DiscreteProbabilityCollectionSampler<Double>(rng,
Arrays.asList(new Double[] {1d, 2d}),
new double[] {0d, Double.POSITIVE_INFINITY});
}
@Test(expected = IllegalArgumentException.class)
public void testPrecondition6() {
// Empty Map<T, Double> not allowed
new DiscreteProbabilityCollectionSampler<Double>(rng,
new HashMap<Double, Double>());
}
@Test(expected = IllegalArgumentException.class)
public void testPrecondition7() {
// Empty List<T> not allowed
new DiscreteProbabilityCollectionSampler<Double>(rng,
Collections.<Double>emptyList(),
new double[0]);
}
@Test
public void testSample() {
final DiscreteProbabilityCollectionSampler<Double> sampler =
new DiscreteProbabilityCollectionSampler<Double>(rng,
Arrays.asList(new Double[] {3d, -1d, 3d, 7d, -2d, 8d}),
new double[] {0.2, 0.2, 0.3, 0.3, 0, 0});
final double expectedMean = 3.4;
final double expectedVariance = 7.84;
final int n = 100000000;
double sum = 0;
double sumOfSquares = 0;
for (int i = 0; i < n; i++) {
final double rand = sampler.sample();
sum += rand;
sumOfSquares += rand * rand;
}
final double mean = sum / n;
Assert.assertEquals(expectedMean, mean, 1e-3);
final double variance = sumOfSquares / n - mean * mean;
Assert.assertEquals(expectedVariance, variance, 2e-3);
}
@Test
public void testSampleUsingMap() {
final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
final List<Integer> items = Arrays.asList(1, 3, 4, 6, 9);
final double[] probabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
final DiscreteProbabilityCollectionSampler<Integer> sampler1 =
new DiscreteProbabilityCollectionSampler<Integer>(rng1, items, probabilities);
// Create a map version. The map iterator must be ordered so use a TreeMap.
final Map<Integer, Double> map = new TreeMap<Integer, Double>();
for (int i = 0; i < probabilities.length; i++) {
map.put(items.get(i), probabilities[i]);
}
final DiscreteProbabilityCollectionSampler<Integer> sampler2 =
new DiscreteProbabilityCollectionSampler<Integer>(rng2, map);
for (int i = 0; i < 50; i++) {
Assert.assertEquals(sampler1.sample(), sampler2.sample());
}
}
/**
* Edge-case test:
* Create a sampler that will return 1 for nextDouble() forcing the search to
* identify the end item of the cumulative probability array.
*/
@Test
public void testSampleWithProbabilityAtLastItem() {
// Ensure the samples pick probability 0 (the first item) and then
// a probability (for the second item) that hits an edge case.
final UniformRandomProvider dummyRng = new UniformRandomProvider() {
private int count;
// CHECKSTYLE: stop all
public long nextLong(long n) { return 0; }
public long nextLong() { return 0; }
public int nextInt(int n) { return 0; }
public int nextInt() { return 0; }
public float nextFloat() { return 0; }
// Return 0 then the given probability
public double nextDouble() { return (count++ == 0) ? 0 : 1.0; }
public void nextBytes(byte[] bytes, int start, int len) {}
public void nextBytes(byte[] bytes) {}
public boolean nextBoolean() { return false; }
// CHECKSTYLE: resume all
};
final List<Double> items = Arrays.asList(new Double[] {1d, 2d});
final DiscreteProbabilityCollectionSampler<Double> sampler =
new DiscreteProbabilityCollectionSampler<Double>(dummyRng,
items,
new double[] {0.5, 0.5});
final Double item1 = sampler.sample();
final Double item2 = sampler.sample();
// Check they are in the list
Assert.assertTrue("Sample item1 is not from the list", items.contains(item1));
Assert.assertTrue("Sample item2 is not from the list", items.contains(item2));
// Test the two samples are different items
Assert.assertNotSame("Item1 and 2 should be different", item1, item2);
}
/**
* Test the SharedStateSampler implementation.
*/
@Test
public void testSharedStateSampler() {
final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
final List<Double> items = Arrays.asList(new Double[] {1d, 2d, 3d, 4d});
final DiscreteProbabilityCollectionSampler<Double> sampler1 =
new DiscreteProbabilityCollectionSampler<Double>(rng1,
items,
new double[] {0.1, 0.2, 0.3, 0.4});
final DiscreteProbabilityCollectionSampler<Double> sampler2 = sampler1.withUniformRandomProvider(rng2);
RandomAssert.assertProduceSameSequence(
new RandomAssert.Sampler<Double>() {
@Override
public Double sample() {
return sampler1.sample();
}
},
new RandomAssert.Sampler<Double>() {
@Override
public Double sample() {
return sampler2.sample();
}
});
}
}