blob: 272f17ef390d66740480a9bde41e60bc1d1a9a3a [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package accord.utils.random;
import java.util.Arrays;
import java.util.function.Supplier;
import accord.utils.Invariants;
import accord.utils.RandomSource;
public class Picker
public static float[] randomWeights(RandomSource random, int length)
float[] weights = new float[length - 1];
float sum = 0;
for (int i = 0 ; i < weights.length ; ++i)
weights[i] = sum += random.nextFloat();
sum += random.nextFloat();
for (int i = 0 ; i < weights.length ; ++i)
weights[i] /= sum;
return weights;
static abstract class Weighted
final RandomSource random;
final float[] weights;
public Weighted(RandomSource random, int length)
this(random, Picker.randomWeights(random, length));
public Weighted(RandomSource random, float[] weights)
this.random = random;
this.weights = weights;
static float[] randomWeights(RandomSource random, float[] bias)
float[] weights = new float[bias.length - 1];
float sum = 0;
for (int i = 0 ; i < weights.length ; ++i)
weights[i] = sum += random.nextFloat() * bias[i];
sum += random.nextFloat() * bias[weights.length];
for (int i = 0 ; i < weights.length ; ++i)
weights[i] /= sum;
return weights;
static float[] normaliseWeights(float[] input)
float[] output = new float[input.length - 1];
float sum = 0;
for (int i = 0 ; i < output.length ; ++i)
output[i] = sum += input[i];
sum += input[output.length];
for (int i = 0 ; i < output.length ; ++i)
output[i] /= sum;
return output;
int pickIndex()
int i = Arrays.binarySearch(weights, random.nextFloat());
if (i < 0) i = -1 - i;
return i;
public static class WeightedObjectPicker<T> extends Weighted implements Supplier<T>
final T[] values;
private WeightedObjectPicker(RandomSource random, T[] values, float[] weights)
super(random, values.length);
this.values = values;
public T get()
return values[pickIndex()];
public static <T> WeightedObjectPicker<T> randomWeighted(RandomSource random, T[] values)
return new WeightedObjectPicker<>(random, values, Picker.randomWeights(random, values.length));
public static <T> WeightedObjectPicker<T> randomWeighted(RandomSource random, T[] values, float[] bias)
Invariants.checkArgument(values.length == bias.length);
return new WeightedObjectPicker<>(random, values, randomWeights(random, bias));
public static <T> WeightedObjectPicker<T> weighted(RandomSource random, T[] values, float[] proportionalWeights)
Invariants.checkArgument(values.length == proportionalWeights.length);
return new WeightedObjectPicker<>(random, values, normaliseWeights(proportionalWeights));