blob: e1193288c72cc8d60a735dddfc38a50628277917 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.transforms;
import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
/**
* {@code PTransform}s for taking samples of the elements in a {@code PCollection}, or samples of
* the values associated with each key in a {@code PCollection} of {@code KV}s.
*
* <p>{@link #fixedSizeGlobally(int)} and {@link #fixedSizePerKey(int)} compute uniformly random
* samples. {@link #any(long)} is faster, but provides no uniformity guarantees.
*
* <p>{@link #combineFn} can also be used manually, in combination with state and with the {@link
* Combine} transform.
*/
public class Sample {
/**
 * Returns a {@link CombineFn} that computes a fixed-sized uniform sample of its inputs.
 *
 * @param sampleSize the number of elements to select; must be {@code >= 0}, otherwise an {@link
 *     IllegalArgumentException} is thrown
 * @param <T> the type of the elements being sampled
 * @return a new {@code FixedSizedSampleFn} configured with {@code sampleSize}
 */
public static <T> CombineFn<T, ?, Iterable<T>> combineFn(int sampleSize) {
  FixedSizedSampleFn<T> uniformSampler = new FixedSizedSampleFn<>(sampleSize);
  return uniformSampler;
}
/**
 * Returns a {@link CombineFn} that computes a fixed-sized potentially non-uniform sample of its
 * inputs.
 *
 * @param sampleSize the maximum number of elements the resulting sample may contain
 * @param <T> the type of the elements being sampled
 * @return a new {@code SampleAnyCombineFn} whose limit is {@code sampleSize}
 */
public static <T> CombineFn<T, ?, Iterable<T>> anyCombineFn(int sampleSize) {
  SampleAnyCombineFn<T> boundedSampler = new SampleAnyCombineFn<>(sampleSize);
  return boundedSampler;
}
/**
 * {@code Sample#any(long)} takes a {@code PCollection<T>} and a limit, and produces a new {@code
 * PCollection<T>} containing up to limit elements of the input {@code PCollection}.
 *
 * <p>If limit is greater than or equal to the size of the input {@code PCollection}, then all the
 * input's elements will be selected.
 *
 * <p>Example of use:
 *
 * <pre>{@code
 * PCollection<String> input = ...;
 * PCollection<String> output = input.apply(Sample.<String>any(100));
 * }</pre>
 *
 * @param <T> the type of the elements of the input and output {@code PCollection}s
 * @param limit the number of elements to take from the input; must be non-negative, otherwise an
 *     {@link IllegalArgumentException} is thrown
 * @return a transform that emits at most {@code limit} elements of its input
 */
public static <T> PTransform<PCollection<T>, PCollection<T>> any(long limit) {
  Any<T> takeAny = new Any<>(limit);
  return takeAny;
}
/**
 * Returns a {@code PTransform} that takes a {@code PCollection<T>}, selects {@code sampleSize}
 * elements, uniformly at random, and returns a {@code PCollection<Iterable<T>>} containing the
 * selected elements. If the input {@code PCollection} has fewer than {@code sampleSize} elements,
 * then the output {@code Iterable<T>} will be all the input's elements.
 *
 * <p>All of the elements of the output {@code PCollection} should fit into main memory of a
 * single worker machine. This operation does not run in parallel.
 *
 * <p>Example of use:
 *
 * <pre>{@code
 * PCollection<String> pc = ...;
 * PCollection<Iterable<String>> sampleOfSize10 =
 *     pc.apply(Sample.fixedSizeGlobally(10));
 * }</pre>
 *
 * @param sampleSize the number of elements to select; must be {@code >= 0}
 * @param <T> the type of the elements
 * @return a transform producing a single {@code Iterable<T>} holding the sampled elements
 */
public static <T> PTransform<PCollection<T>, PCollection<Iterable<T>>> fixedSizeGlobally(
    int sampleSize) {
  FixedSizeGlobally<T> globalSampler = new FixedSizeGlobally<>(sampleSize);
  return globalSampler;
}
/**
 * Returns a {@code PTransform} that takes an input {@code PCollection<KV<K, V>>} and returns a
 * {@code PCollection<KV<K, Iterable<V>>>} that contains an output element mapping each distinct
 * key in the input {@code PCollection} to a sample of {@code sampleSize} values associated with
 * that key in the input {@code PCollection}, taken uniformly at random. If a key in the input
 * {@code PCollection} has fewer than {@code sampleSize} values associated with it, then the
 * output {@code Iterable<V>} associated with that key will be all the values associated with that
 * key in the input {@code PCollection}.
 *
 * <p>Example of use:
 *
 * <pre>{@code
 * PCollection<KV<String, Integer>> pc = ...;
 * PCollection<KV<String, Iterable<Integer>>> sampleOfSize10PerKey =
 *     pc.apply(Sample.<String, Integer>fixedSizePerKey(10));
 * }</pre>
 *
 * @param sampleSize the number of values to select for each distinct key; must be {@code >= 0}
 * @param <K> the type of the keys
 * @param <V> the type of the values
 * @return a transform producing, per key, an {@code Iterable<V>} holding the sampled values
 */
public static <K, V>
    PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> fixedSizePerKey(
        int sampleSize) {
  FixedSizePerKey<K, V> perKeySampler = new FixedSizePerKey<>(sampleSize);
  return perKeySampler;
}
/////////////////////////////////////////////////////////////////////////////
/** Implementation of {@link #any(long)}. */
private static class Any<T> extends PTransform<PCollection<T>, PCollection<T>> {
  /** Upper bound on the number of elements this transform emits. */
  private final long limit;

  /**
   * Constructs an {@code Any<T>} PTransform that, when applied, produces a new PCollection
   * containing up to {@code limit} elements of its input {@code PCollection}.
   *
   * @param limit the maximum number of elements to emit; must be non-negative
   * @throws IllegalArgumentException if {@code limit} is negative
   */
  private Any(long limit) {
    checkArgument(limit >= 0, "Expected non-negative limit, received %s.", limit);
    this.limit = limit;
  }

  /**
   * Combines the input globally into a single bounded {@code Iterable} (without emitting a
   * default value) and then flattens that {@code Iterable} back into individual elements.
   */
  @Override
  public PCollection<T> expand(PCollection<T> in) {
    PCollection<Iterable<T>> sampled =
        in.apply(Combine.globally(new SampleAnyCombineFn<T>(limit)).withoutDefaults());
    return sampled.apply(Flatten.iterables());
  }

  /** Registers {@code limit} as the "sampleSize" display item, labelled "Sample Size". */
  @Override
  public void populateDisplayData(DisplayData.Builder builder) {
    super.populateDisplayData(builder);
    builder.add(DisplayData.item("sampleSize", limit).withLabel("Sample Size"));
  }
}
/** Implementation of {@link #fixedSizeGlobally(int)}. */
private static class FixedSizeGlobally<T>
    extends PTransform<PCollection<T>, PCollection<Iterable<T>>> {
  /** Number of elements to select from the whole input. */
  private final int sampleSize;

  /**
   * Creates the transform, rejecting an invalid size immediately rather than deferring the
   * failure to {@link #expand}, where the underlying {@code FixedSizedSampleFn} is built.
   *
   * @param sampleSize the number of elements to select; must be {@code >= 0}
   * @throws IllegalArgumentException if {@code sampleSize} is negative
   */
  private FixedSizeGlobally(int sampleSize) {
    // Same exception type and message that FixedSizedSampleFn reports for a negative size.
    checkArgument(sampleSize >= 0, "sample size must be >= 0");
    this.sampleSize = sampleSize;
  }

  /** Applies a global {@link Combine} using a {@code FixedSizedSampleFn} of this size. */
  @Override
  public PCollection<Iterable<T>> expand(PCollection<T> input) {
    return input.apply(Combine.globally(new FixedSizedSampleFn<>(sampleSize)));
  }

  /** Registers {@code sampleSize} as the "sampleSize" display item, labelled "Sample Size". */
  @Override
  public void populateDisplayData(DisplayData.Builder builder) {
    super.populateDisplayData(builder);
    builder.add(DisplayData.item("sampleSize", sampleSize).withLabel("Sample Size"));
  }
}
/** Implementation of {@link #fixedSizePerKey(int)}. */
private static class FixedSizePerKey<K, V>
    extends PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> {
  /** Number of values to select for each distinct key. */
  private final int sampleSize;

  /**
   * Creates the transform, rejecting an invalid size immediately rather than deferring the
   * failure to {@link #expand}, where the underlying {@code FixedSizedSampleFn} is built.
   *
   * @param sampleSize the number of values to select per key; must be {@code >= 0}
   * @throws IllegalArgumentException if {@code sampleSize} is negative
   */
  private FixedSizePerKey(int sampleSize) {
    // Same exception type and message that FixedSizedSampleFn reports for a negative size.
    checkArgument(sampleSize >= 0, "sample size must be >= 0");
    this.sampleSize = sampleSize;
  }

  /** Applies a per-key {@link Combine} using a {@code FixedSizedSampleFn} of this size. */
  @Override
  public PCollection<KV<K, Iterable<V>>> expand(PCollection<KV<K, V>> input) {
    return input.apply(Combine.perKey(new FixedSizedSampleFn<>(sampleSize)));
  }

  /** Registers {@code sampleSize} as the "sampleSize" display item, labelled "Sample Size". */
  @Override
  public void populateDisplayData(DisplayData.Builder builder) {
    super.populateDisplayData(builder);
    builder.add(DisplayData.item("sampleSize", sampleSize).withLabel("Sample Size"));
  }
}
/**
 * A {@link CombineFn} that combines into a {@link List} of up to limit elements.
 *
 * <p>Inputs are appended to the accumulator until it holds {@code limit} elements; further inputs
 * are dropped. No attempt is made to choose elements at random.
 *
 * @param <T> the type of the elements
 */
private static class SampleAnyCombineFn<T> extends CombineFn<T, List<T>, Iterable<T>> {
  /**
   * Cap on the capacity pre-allocated for a fresh accumulator. {@code limit} is a {@code long}
   * and may legitimately be far larger than the input, so it must never be used directly as an
   * {@code ArrayList} capacity: narrowing it to {@code int} can yield a negative value, and a
   * large valid value would allocate a huge backing array up front.
   */
  private static final int MAX_INITIAL_CAPACITY = 1024;

  /** Maximum number of elements an accumulator may hold. */
  private final long limit;

  /**
   * @param limit the maximum number of elements to retain; must be non-negative
   * @throws IllegalArgumentException if {@code limit} is negative
   */
  private SampleAnyCombineFn(long limit) {
    checkArgument(limit >= 0, "Expected non-negative limit, received %s.", limit);
    this.limit = limit;
  }

  /** Returns an empty list, pre-sized to {@code limit} but never beyond the capacity cap. */
  @Override
  public List<T> createAccumulator() {
    // Math.min runs in long arithmetic, so the narrowing cast below is always in range.
    return new ArrayList<>((int) Math.min(limit, MAX_INITIAL_CAPACITY));
  }

  /** Appends {@code input} unless the accumulator already holds {@code limit} elements. */
  @Override
  public List<T> addInput(List<T> accumulator, T input) {
    if (accumulator.size() < limit) {
      accumulator.add(input);
    }
    return accumulator;
  }

  /**
   * Merges by topping up the first accumulator with elements from the remaining ones, stopping
   * as soon as it holds {@code limit} elements. Returns a fresh accumulator when there is
   * nothing to merge.
   */
  @Override
  public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
    Iterator<List<T>> iter = accumulators.iterator();
    if (!iter.hasNext()) {
      return createAccumulator();
    }
    // The first accumulator is reused (and mutated) as the merge target.
    List<T> res = iter.next();
    while (iter.hasNext()) {
      for (T t : iter.next()) {
        if (res.size() >= limit) {
          return res;
        }
        res.add(t);
      }
    }
    return res;
  }

  /** Returns the accumulated list itself as the output. */
  @Override
  public Iterable<T> extractOutput(List<T> accumulator) {
    return accumulator;
  }
}
/**
 * {@code CombineFn} that computes a fixed-size sample of a collection of values.
 *
 * <p>Each input is paired with a freshly drawn random integer key and handed to a {@link
 * Top.TopCombineFn} that is bounded at {@code sampleSize} and ordered by {@link KV.OrderByKey};
 * the output is the values of whatever that heap retains.
 *
 * @param <T> the type of the elements
 */
public static class FixedSizedSampleFn<T>
    extends CombineFn<
        T, Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>>, Iterable<T>> {
  /** Number of elements to keep in the sample. */
  private final int sampleSize;

  /** Delegate that maintains the bounded heap of (random key, element) pairs. */
  private final Top.TopCombineFn<KV<Integer, T>, SerializableComparator<KV<Integer, T>>>
      topCombineFn;

  // NOTE(review): this is a non-transient instance field, so if the fn is Java-serialized the
  // generator state presumably travels with it -- confirm that deserialized copies do not end up
  // replaying the same key sequence.
  private final Random rand = new Random();

  /**
   * @param sampleSize the number of elements to select; must be {@code >= 0}
   * @throws IllegalArgumentException if {@code sampleSize} is negative
   */
  private FixedSizedSampleFn(int sampleSize) {
    checkArgument(sampleSize >= 0, "sample size must be >= 0");
    this.sampleSize = sampleSize;
    this.topCombineFn = new Top.TopCombineFn<>(sampleSize, new KV.OrderByKey<>());
  }

  /** Returns a new, empty accumulator obtained from the delegate. */
  @Override
  public Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>>
      createAccumulator() {
    return topCombineFn.createAccumulator();
  }

  /** Tags {@code input} with a random integer key and offers the pair to the accumulator. */
  @Override
  public Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>> addInput(
      Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>> accumulator,
      T input) {
    KV<Integer, T> tagged = KV.of(rand.nextInt(), input);
    accumulator.addInput(tagged);
    return accumulator;
  }

  /** Merges the given accumulators by delegating to the underlying {@code TopCombineFn}. */
  @Override
  public Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>>
      mergeAccumulators(
          Iterable<Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>>>
              accumulators) {
    return topCombineFn.mergeAccumulators(accumulators);
  }

  /** Strips the random keys from the accumulator's output and returns only the values. */
  @Override
  public Iterable<T> extractOutput(
      Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>> accumulator) {
    List<T> sampled = new ArrayList<>();
    Iterator<KV<Integer, T>> retained = accumulator.extractOutput().iterator();
    while (retained.hasNext()) {
      sampled.add(retained.next().getValue());
    }
    return sampled;
  }

  /**
   * Returns the delegate's accumulator coder, built over a {@link KvCoder} that encodes the
   * random key with {@link BigEndianIntegerCoder} and the element with {@code inputCoder}.
   */
  @Override
  public Coder<Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>>>
      getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) {
    KvCoder<Integer, T> taggedCoder = KvCoder.of(BigEndianIntegerCoder.of(), inputCoder);
    return topCombineFn.getAccumulatorCoder(registry, taggedCoder);
  }

  /** Returns an {@link IterableCoder} wrapping {@code inputCoder}. */
  @Override
  public Coder<Iterable<T>> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
    return IterableCoder.of(inputCoder);
  }

  /** Registers {@code sampleSize} as the "sampleSize" display item, labelled "Sample Size". */
  @Override
  public void populateDisplayData(DisplayData.Builder builder) {
    super.populateDisplayData(builder);
    builder.add(DisplayData.item("sampleSize", sampleSize).withLabel("Sample Size"));
  }
}
}