blob: 383c1f82c91ad5891122b744e24a3c0adcc34050 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.testing;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.hamcrest.Matcher;
/**
* Utilities for testing {@link CombineFn CombineFns}. Ensures that the {@link CombineFn} gives
* correct results across various permutations and shardings of the input.
*/
public class CombineFnTester {
  /**
   * Tests that the {@link CombineFn}, when applied to the provided input, produces the provided
   * output. Tests a variety of permutations of the input.
   *
   * <p>The checks run once against {@code input} in its given order and once against a shuffled
   * copy of it. The caller's list is never modified, so unmodifiable lists are accepted.
   *
   * @param fn the {@link CombineFn} under test
   * @param input the elements to combine
   * @param expected the output the {@link CombineFn} must produce, compared via {@code is(expected)}
   */
  public static <InputT, AccumT, OutputT> void testCombineFn(
      CombineFn<InputT, AccumT, OutputT> fn, List<InputT> input, final OutputT expected) {
    testCombineFn(fn, input, is(expected));
    // Shuffle a private copy: shuffling the argument in place would reorder the caller's data and
    // throw UnsupportedOperationException for unmodifiable lists.
    List<InputT> shuffledInput = new ArrayList<>(input);
    Collections.shuffle(shuffledInput);
    testCombineFn(fn, shuffledInput, is(expected));
  }

  /**
   * Tests that the {@link CombineFn}, when applied to the provided input, produces an output that
   * satisfies the provided matcher, across several shardings of the input.
   *
   * <p>The shardings exercised are: a single shard holding all input; two even shards; for inputs
   * of more than four elements, {@code size / 2} and roughly {@code sqrt(size)} even shards; and
   * exponentially sized shards with bases 1.4, 2 and {@code e}.
   *
   * @param fn the {@link CombineFn} under test
   * @param input the elements to combine; read but not modified
   * @param matcher the matcher every extracted output must satisfy
   */
  public static <InputT, AccumT, OutputT> void testCombineFn(
      CombineFn<InputT, AccumT, OutputT> fn, List<InputT> input, Matcher<? super OutputT> matcher) {
    int size = input.size();
    checkCombineFnShardsMultipleOrders(fn, Collections.singletonList(input), matcher);
    checkCombineFnShardsMultipleOrders(fn, shardEvenly(input, 2), matcher);
    if (size > 4) {
      checkCombineFnShardsMultipleOrders(fn, shardEvenly(input, size / 2), matcher);
      checkCombineFnShardsMultipleOrders(
          fn, shardEvenly(input, (int) (size / Math.sqrt(size))), matcher);
    }
    checkCombineFnShardsMultipleOrders(fn, shardExponentially(input, 1.4), matcher);
    checkCombineFnShardsMultipleOrders(fn, shardExponentially(input, 2), matcher);
    checkCombineFnShardsMultipleOrders(fn, shardExponentially(input, Math.E), matcher);
  }

  /**
   * Runs the single-merge, empty-accumulator and incremental-merge checks against the shards in
   * their given order, then again against a randomly reordered copy of the shard list.
   */
  private static <InputT, AccumT, OutputT> void checkCombineFnShardsMultipleOrders(
      CombineFn<InputT, AccumT, OutputT> fn,
      List<? extends Iterable<InputT>> shards,
      Matcher<? super OutputT> matcher) {
    checkCombineFnShardsSingleMerge(fn, shards, matcher);
    checkCombineFnShardsWithEmptyAccumulators(fn, shards, matcher);
    checkCombineFnShardsIncrementalMerging(fn, shards, matcher);

    // Reorder a copy rather than the argument: the argument may be an immutable list (for example
    // Collections.singletonList) on which an in-place shuffle is not guaranteed to be supported.
    List<Iterable<InputT>> reorderedShards = new ArrayList<>(shards);
    Collections.shuffle(reorderedShards);
    checkCombineFnShardsSingleMerge(fn, reorderedShards, matcher);
    checkCombineFnShardsWithEmptyAccumulators(fn, reorderedShards, matcher);
    checkCombineFnShardsIncrementalMerging(fn, reorderedShards, matcher);
  }

  /**
   * Builds one accumulator per shard, merges them all in a single {@code mergeAccumulators} call
   * and asserts that the extracted output satisfies the matcher.
   */
  private static <InputT, AccumT, OutputT> void checkCombineFnShardsSingleMerge(
      CombineFn<InputT, AccumT, OutputT> fn,
      Iterable<? extends Iterable<InputT>> shards,
      Matcher<? super OutputT> matcher) {
    List<AccumT> accumulators = combineInputs(fn, shards);
    AccumT merged = fn.mergeAccumulators(accumulators);
    assertThat(fn.extractOutput(merged), matcher);
  }

  /**
   * Same as the single-merge check, but with a freshly created (empty) accumulator placed at both
   * the front and the back of the list that is merged.
   */
  private static <InputT, AccumT, OutputT> void checkCombineFnShardsWithEmptyAccumulators(
      CombineFn<InputT, AccumT, OutputT> fn,
      Iterable<? extends Iterable<InputT>> shards,
      Matcher<? super OutputT> matcher) {
    List<AccumT> accumulators = combineInputs(fn, shards);
    accumulators.add(0, fn.createAccumulator());
    accumulators.add(fn.createAccumulator());
    AccumT merged = fn.mergeAccumulators(accumulators);
    assertThat(fn.extractOutput(merged), matcher);
  }

  /**
   * Merges the per-shard accumulators one at a time into a running accumulator, extracting an
   * output after every step, and asserts that the final extracted output satisfies the matcher.
   */
  private static <InputT, AccumT, OutputT> void checkCombineFnShardsIncrementalMerging(
      CombineFn<InputT, AccumT, OutputT> fn,
      List<? extends Iterable<InputT>> shards,
      Matcher<? super OutputT> matcher) {
    // With no shards the loop below never runs, so start from an empty accumulator; otherwise the
    // first per-shard accumulator becomes the running accumulator.
    AccumT accumulator = shards.isEmpty() ? fn.createAccumulator() : null;
    for (AccumT inputAccum : combineInputs(fn, shards)) {
      if (accumulator == null) {
        accumulator = inputAccum;
      } else {
        accumulator = fn.mergeAccumulators(Arrays.asList(accumulator, inputAccum));
      }
      fn.extractOutput(accumulator); // Extract output to simulate multiple firings
    }
    assertThat(fn.extractOutput(accumulator), matcher);
  }

  /**
   * Creates one accumulator per shard by adding every element of the shard to a fresh accumulator.
   * Every other accumulator (shards at index 0, 2, 4, ...) is additionally passed through {@code
   * fn.compact}, so that both compacted and non-compacted accumulators are produced.
   *
   * @return a new, mutable list holding the accumulators in shard order
   */
  private static <InputT, AccumT, OutputT> List<AccumT> combineInputs(
      CombineFn<InputT, AccumT, OutputT> fn, Iterable<? extends Iterable<InputT>> shards) {
    List<AccumT> accumulators = new ArrayList<>();
    int maybeCompact = 0;
    for (Iterable<InputT> shard : shards) {
      AccumT accumulator = fn.createAccumulator();
      for (InputT elem : shard) {
        accumulator = fn.addInput(accumulator, elem);
      }
      if (maybeCompact++ % 2 == 0) {
        accumulator = fn.compact(accumulator);
      }
      accumulators.add(accumulator);
    }
    return accumulators;
  }

  /**
   * Splits {@code input} into {@code numShards} contiguous {@code subList} views whose sizes differ
   * by at most one element. Shards may be empty when {@code numShards} exceeds the input size.
   *
   * @throws IllegalArgumentException if {@code numShards} is less than 1
   */
  private static <T> List<List<T>> shardEvenly(List<T> input, int numShards) {
    if (numShards < 1) {
      throw new IllegalArgumentException("numShards must be at least 1, but was " + numShards);
    }
    List<List<T>> shards = new ArrayList<>(numShards);
    int size = input.size();
    for (int i = 0; i < numShards; i++) {
      // Compute the boundaries in long arithmetic so i * size cannot overflow for large inputs;
      // both results are at most size, so narrowing back to int is safe.
      int from = (int) ((long) i * size / numShards);
      int to = (int) ((long) (i + 1) * size / numShards);
      shards.add(input.subList(from, to));
    }
    return shards;
  }

  /**
   * Splits {@code input} into contiguous {@code subList} views by repeatedly cutting off the tail
   * {@code [end / base, end)} of the remaining prefix, so the first shard returned is the largest
   * and covers the end of the input. An empty input yields no shards.
   *
   * @throws IllegalArgumentException if {@code base} is not greater than 1 (including NaN), since
   *     the remaining prefix would then never shrink
   */
  private static <T> List<List<T>> shardExponentially(List<T> input, double base) {
    // An explicit check instead of `assert`: assertions are disabled by default, and a base that
    // is not greater than 1 would make the loop below run forever.
    if (!(base > 1.0)) {
      throw new IllegalArgumentException("base must be greater than 1.0, but was " + base);
    }
    List<List<T>> shards = new ArrayList<>();
    int end = input.size();
    while (end > 0) {
      int start = (int) (end / base);
      shards.add(input.subList(start, end));
      end = start;
    }
    return shards;
  }
}