// blob: 86e6358f6e6e058ee8ca49c3a07f8354d4e4b8f3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.rng.sampling.distribution;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.SharedStateObjectSampler;
/**
 * Sampling from a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet
 * distribution</a>.
 *
 * <p>A sample is created by drawing one gamma deviate per category and scaling
 * every deviate by the reciprocal of their sum.</p>
 *
 * <p>Sampling uses:</p>
 *
 * <ul>
 *   <li>{@link UniformRandomProvider#nextLong()}
 *   <li>{@link UniformRandomProvider#nextDouble()}
 * </ul>
 *
 * @since 1.4
 */
public abstract class DirichletSampler implements SharedStateObjectSampler<double[]> {
    /** The minimum number of categories. */
    private static final int MIN_CATEGORIES = 2;

    /** RNG (used for the toString() method). */
    private final UniformRandomProvider rng;

    /**
     * Sample from a Dirichlet distribution with different concentration parameters
     * for each category.
     */
    private static final class GeneralDirichletSampler extends DirichletSampler {
        /** Samplers for each category. */
        private final SharedStateContinuousSampler[] samplers;

        /**
         * @param rng Generator of uniformly distributed random numbers.
         * @param samplers Samplers for each category.
         */
        GeneralDirichletSampler(UniformRandomProvider rng,
                                SharedStateContinuousSampler[] samplers) {
            super(rng);
            // Array is stored directly as it is generated within the DirichletSampler class
            this.samplers = samplers;
        }

        /** {@inheritDoc} */
        @Override
        protected int getK() {
            // One sampler exists per category
            return samplers.length;
        }

        /** {@inheritDoc} */
        @Override
        protected double nextGamma(int i) {
            return samplers[i].sample();
        }

        /** {@inheritDoc} */
        @Override
        public GeneralDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
            // Rebuild every category sampler so that each one uses the new generator
            final SharedStateContinuousSampler[] newSamplers =
                new SharedStateContinuousSampler[samplers.length];
            for (int i = 0; i < newSamplers.length; i++) {
                newSamplers[i] = samplers[i].withUniformRandomProvider(rng);
            }
            return new GeneralDirichletSampler(rng, newSamplers);
        }
    }

    /**
     * Sample from a symmetric Dirichlet distribution with the same concentration parameter
     * for each category.
     */
    private static final class SymmetricDirichletSampler extends DirichletSampler {
        /** Number of categories. */
        private final int k;
        /** Sampler for the categories. */
        private final SharedStateContinuousSampler sampler;

        /**
         * @param rng Generator of uniformly distributed random numbers.
         * @param k Number of categories.
         * @param sampler Sampler for the categories.
         */
        SymmetricDirichletSampler(UniformRandomProvider rng,
                                  int k,
                                  SharedStateContinuousSampler sampler) {
            super(rng);
            this.k = k;
            this.sampler = sampler;
        }

        /** {@inheritDoc} */
        @Override
        protected int getK() {
            return k;
        }

        /** {@inheritDoc} */
        @Override
        protected double nextGamma(int i) {
            // The category index is ignored: a single sampler serves all categories
            return sampler.sample();
        }

        /** {@inheritDoc} */
        @Override
        public SymmetricDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
            return new SymmetricDirichletSampler(rng, k, sampler.withUniformRandomProvider(rng));
        }
    }

    /**
     * @param rng Generator of uniformly distributed random numbers.
     */
    private DirichletSampler(UniformRandomProvider rng) {
        this.rng = rng;
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        return "Dirichlet deviate [" + rng.toString() + "]";
    }

    /**
     * Creates a sample from the distribution.
     *
     * <p>One gamma deviate is generated per category using {@link #nextGamma(int)} and
     * each deviate is multiplied by the reciprocal of the sum of all deviates. If that
     * reciprocal is not a strictly positive finite value (e.g. all deviates are zero)
     * the sample is discarded and a new one is generated.</p>
     *
     * @return a new array of length {@link #getK()} containing the normalized deviates
     */
    @Override
    public double[] sample() {
        // Create Gamma(alpha_i, 1) deviates for all alpha
        final double[] y = new double[getK()];
        double sum = 0;
        for (int i = 0; i < y.length; i++) {
            final double yi = nextGamma(i);
            sum += yi;
            y[i] = yi;
        }
        // Normalize by multiplying by the reciprocal of the sum of the samples
        final double norm = 1.0 / sum;
        // Detect an invalid normalization, e.g. cases of all zero samples
        if (!isNonZeroPositiveFinite(norm)) {
            // Sample again using recursion.
            // A stack overflow due to a broken RNG will eventually occur
            // rather than the alternative which is an infinite loop.
            return sample();
        }
        // Normalize
        for (int i = 0; i < y.length; i++) {
            y[i] *= norm;
        }
        return y;
    }

    /**
     * Gets the number of categories.
     *
     * @return k
     */
    protected abstract int getK();

    /**
     * Create a gamma sample for the given category.
     *
     * @param category Category.
     * @return the sample
     */
    protected abstract double nextGamma(int category);

    /** {@inheritDoc} */
    // Redeclare the signature to return a DirichletSampler not a SharedStateObjectSampler<double[]>
    @Override
    public abstract DirichletSampler withUniformRandomProvider(UniformRandomProvider rng);

    /**
     * Creates a new Dirichlet distribution sampler.
     *
     * <p>A separate gamma sampler is created for each concentration parameter.</p>
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param alpha Concentration parameters.
     * @return the sampler
     * @throws IllegalArgumentException if the number of concentration parameters
     * is less than 2; or if any concentration parameter is not strictly positive
     * and finite (this includes NaN).
     */
    public static DirichletSampler of(UniformRandomProvider rng,
                                      double... alpha) {
        validateNumberOfCategories(alpha.length);
        final SharedStateContinuousSampler[] samplers = new SharedStateContinuousSampler[alpha.length];
        for (int i = 0; i < samplers.length; i++) {
            samplers[i] = createSampler(rng, alpha[i]);
        }
        return new GeneralDirichletSampler(rng, samplers);
    }

    /**
     * Creates a new symmetric Dirichlet distribution sampler using the same concentration
     * parameter for each category.
     *
     * <p>A single gamma sampler is created and used for all categories.</p>
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param k Number of categories.
     * @param alpha Concentration parameter.
     * @return the sampler
     * @throws IllegalArgumentException if the number of categories is
     * less than 2; or if the concentration parameter is not strictly positive
     * and finite (this includes NaN).
     */
    public static DirichletSampler symmetric(UniformRandomProvider rng,
                                             int k,
                                             double alpha) {
        validateNumberOfCategories(k);
        final SharedStateContinuousSampler sampler = createSampler(rng, alpha);
        return new SymmetricDirichletSampler(rng, k, sampler);
    }

    /**
     * Validate the number of categories.
     *
     * @param k Number of categories.
     * @throws IllegalArgumentException if the number of categories is
     * less than 2.
     */
    private static void validateNumberOfCategories(int k) {
        if (k < MIN_CATEGORIES) {
            throw new IllegalArgumentException("Invalid number of categories: " + k);
        }
    }

    /**
     * Creates a gamma sampler for a category with the given concentration parameter.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param alpha Concentration parameter.
     * @return the sampler
     * @throws IllegalArgumentException if the concentration parameter is not strictly
     * positive and finite (this includes NaN).
     */
    private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng,
                                                              double alpha) {
        // Negation of logic will detect NaN
        if (!isNonZeroPositiveFinite(alpha)) {
            throw new IllegalArgumentException("Invalid concentration: " + alpha);
        }
        // Create a Gamma(shape=alpha, scale=1) sampler.
        if (alpha == 1) {
            // Special case
            // Gamma(shape=1, scale=1) == Exponential(mean=1)
            return ZigguratSampler.Exponential.of(rng);
        }
        return AhrensDieterMarsagliaTsangGammaSampler.of(rng, alpha, 1);
    }

    /**
     * Return true if the value is non-zero, positive and finite.
     *
     * <p>Both comparisons are false for NaN, so NaN returns false.</p>
     *
     * @param x Value.
     * @return true if non-zero positive finite
     */
    private static boolean isNonZeroPositiveFinite(double x) {
        return x > 0 && x < Double.POSITIVE_INFINITY;
    }
}