blob: 7b80d36c8a1767e3a139c4a20fbe6db07cdf63d0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.rng.sampling;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.Test;
import org.apache.commons.math3.stat.inference.ChiSquareTest;
import org.apache.commons.math3.util.CombinatoricsUtils;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.simple.RandomSource;
/**
* Tests for {@link CombinationSampler}.
*/
public class CombinationSamplerTest {
    /** Source of randomness shared by the tests. */
    private final UniformRandomProvider rng = RandomSource.create(RandomSource.XOR_SHIFT_1024_S);

    /**
     * Checks that, for every valid k, a sample has length k and contains
     * k distinct values from the domain 0 to n-1.
     */
    @Test
    public void testSampleIsInDomain() {
        final int n = 6;
        for (int k = 1; k <= n; k++) {
            final CombinationSampler sampler = new CombinationSampler(rng, n, k);
            final int[] random = sampler.sample();
            Assert.assertEquals("Incorrect sample length", k, random.length);
            // A combination never repeats an element of the domain.
            final boolean[] seen = new boolean[n];
            for (int s : random) {
                assertIsInDomain(n, s);
                if (seen[s]) {
                    Assert.fail("Duplicate sample value " + s + " in " + Arrays.toString(random));
                }
                seen[s] = true;
            }
        }
    }

    @Test
    public void testUniformWithKlessThanHalfN() {
        final int n = 8;
        final int k = 2;
        assertUniformSamples(n, k);
    }

    @Test
    public void testUniformWithKmoreThanHalfN() {
        final int n = 8;
        final int k = 6;
        assertUniformSamples(n, k);
    }

    @Test
    public void testSampleWhenNequalsKIsNotShuffled() {
        // Check n == k boundary case.
        // This is allowed but the sample is not shuffled.
        for (int n = 1; n < 3; n++) {
            final int k = n;
            final CombinationSampler sampler = new CombinationSampler(rng, n, k);
            final int[] sample = sampler.sample();
            Assert.assertEquals("Incorrect sample length", n, sample.length);
            for (int i = 0; i < n; i++) {
                Assert.assertEquals("Sample was shuffled", i, sample[i]);
            }
        }
    }

    @Test(expected = IllegalArgumentException.class)
    public void testKgreaterThanNThrows() {
        // Must fail for k > n.
        final int n = 2;
        final int k = 3;
        new CombinationSampler(rng, n, k);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testNequalsZeroThrows() {
        // Must fail for n = 0.
        // Note: any k >= 1 also exceeds n = 0, so this cannot isolate the n check.
        final int n = 0;
        final int k = 3;
        new CombinationSampler(rng, n, k);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testKequalsZeroThrows() {
        // Must fail for k = 0.
        final int n = 2;
        final int k = 0;
        new CombinationSampler(rng, n, k);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testNisNegativeThrows() {
        // Must fail for n < 0.
        // Note: any k >= 1 also exceeds a negative n, so this cannot isolate the n check.
        final int n = -1;
        final int k = 3;
        new CombinationSampler(rng, n, k);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testKisNegativeThrows() {
        // Must fail for k < 0.
        // n is valid so that the negative k is the only reason for the failure.
        final int n = 2;
        final int k = -1;
        new CombinationSampler(rng, n, k);
    }

    //// Support methods.

    /**
     * Asserts the sample value is in the range 0 to n-1.
     *
     * @param n the size of the domain
     * @param value the sample value
     */
    private static void assertIsInDomain(int n, int value) {
        if (value < 0 || value >= n) {
            Assert.fail("sample " + value + " not in the domain " + n);
        }
    }

    /**
     * Asserts that samples of size k from a domain of size n are uniformly
     * distributed over the C(n, k) possible combinations using a chi-square test.
     *
     * @param n the size of the domain (a lookup table of 2^n entries is created)
     * @param k the size of each sample
     */
    private void assertUniformSamples(int n, int k) {
        // The C(n, k) sampler generates a sample of unspecified order.
        // To test this, each combination is identified by the integer whose
        // set bits are the sampled values, so the order is irrelevant.
        // Integers in [0, 2^n) with exactly k set bits are valid combinations
        // and are allocated a unique non-negative code; all others map to -1.
        final int totalBitCombinations = 1 << n;
        final int[] codeLookup = new int[totalBitCombinations];
        Arrays.fill(codeLookup, -1);
        int codes = 0;
        for (int i = 0; i < totalBitCombinations; i++) {
            if (Integer.bitCount(i) == k) {
                // This is a valid sample so allocate a code
                codeLookup[i] = codes++;
            }
        }

        // The number of combinations C(n, k) is the binomial coefficient
        Assert.assertEquals("Incorrect number of combination codes",
            CombinatoricsUtils.binomialCoefficient(n, k), codes);

        final long[] observed = new long[codes];
        final int numSamples = 6000;
        final CombinationSampler sampler = new CombinationSampler(rng, n, k);
        for (int i = 0; i < numSamples; i++) {
            observed[findCode(codeLookup, sampler.sample())]++;
        }

        // Chi squared test of uniformity: every combination is equally likely.
        final double numExpected = numSamples / (double) codes;
        final double[] expected = new double[codes];
        Arrays.fill(expected, numExpected);
        final ChiSquareTest chiSquareTest = new ChiSquareTest();
        // Pass if we cannot reject the null hypothesis that the distributions are the same.
        Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
    }

    /**
     * Maps a sample to the code of the combination it represents.
     *
     * @param codeLookup table of size 2^n mapping a bit pattern to its code
     * (negative when the pattern is not a valid combination)
     * @param sample the sample
     * @return the non-negative code of the combination
     */
    private static int findCode(int[] codeLookup, int[] sample) {
        // The table has 2^n entries, so n is the number of trailing zeros of its length.
        final int n = Integer.numberOfTrailingZeros(codeLookup.length);
        // Each sample index is used to set a bit in an integer.
        // The resulting bits should be a valid code.
        int bits = 0;
        for (int s : sample) {
            // Validate before shifting: Java masks int shift distances to 5 bits,
            // so an out-of-range value could otherwise alias a valid bit or make
            // the pattern negative.
            assertIsInDomain(n, s);
            bits |= 1 << s;
        }
        final int code = codeLookup[bits];
        if (code < 0) {
            // Wrong number of distinct values (duplicates or wrong sample length).
            Assert.fail("Bad bit code: " + Arrays.toString(sample));
        }
        return code;
    }
}