| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.sdk.transforms; |
| |
| import static org.apache.beam.sdk.TestUtils.KvMatcher.isKv; |
| import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; |
| import static org.hamcrest.CoreMatchers.equalTo; |
| import static org.hamcrest.CoreMatchers.hasItem; |
| import static org.hamcrest.Matchers.empty; |
| import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; |
| import static org.junit.Assert.assertThat; |
| |
| import java.io.DataInputStream; |
| import java.io.DataOutputStream; |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.io.OutputStream; |
| import java.io.Serializable; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.Collection; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.concurrent.ThreadLocalRandom; |
| import java.util.stream.Collectors; |
| import java.util.stream.StreamSupport; |
| import org.apache.beam.sdk.coders.AtomicCoder; |
| import org.apache.beam.sdk.coders.BigEndianIntegerCoder; |
| import org.apache.beam.sdk.coders.Coder; |
| import org.apache.beam.sdk.coders.CoderProviders; |
| import org.apache.beam.sdk.coders.KvCoder; |
| import org.apache.beam.sdk.coders.MapCoder; |
| import org.apache.beam.sdk.coders.StringUtf8Coder; |
| import org.apache.beam.sdk.coders.VarIntCoder; |
| import org.apache.beam.sdk.testing.DataflowPortabilityApiUnsupported; |
| import org.apache.beam.sdk.testing.LargeKeys; |
| import org.apache.beam.sdk.testing.NeedsRunner; |
| import org.apache.beam.sdk.testing.PAssert; |
| import org.apache.beam.sdk.testing.TestPipeline; |
| import org.apache.beam.sdk.testing.TestStream; |
| import org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime; |
| import org.apache.beam.sdk.testing.ValidatesRunner; |
| import org.apache.beam.sdk.transforms.display.DisplayData; |
| import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime; |
| import org.apache.beam.sdk.transforms.windowing.FixedWindows; |
| import org.apache.beam.sdk.transforms.windowing.GlobalWindow; |
| import org.apache.beam.sdk.transforms.windowing.IntervalWindow; |
| import org.apache.beam.sdk.transforms.windowing.InvalidWindows; |
| import org.apache.beam.sdk.transforms.windowing.Repeatedly; |
| import org.apache.beam.sdk.transforms.windowing.Sessions; |
| import org.apache.beam.sdk.transforms.windowing.SlidingWindows; |
| import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; |
| import org.apache.beam.sdk.transforms.windowing.Window; |
| import org.apache.beam.sdk.values.KV; |
| import org.apache.beam.sdk.values.PBegin; |
| import org.apache.beam.sdk.values.PCollection; |
| import org.apache.beam.sdk.values.TimestampedValue; |
| import org.apache.beam.sdk.values.WindowingStrategy; |
| import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList; |
| import org.hamcrest.Matcher; |
| import org.joda.time.Duration; |
| import org.joda.time.Instant; |
| import org.junit.Assert; |
| import org.junit.Rule; |
| import org.junit.Test; |
| import org.junit.experimental.categories.Category; |
| import org.junit.rules.ExpectedException; |
| import org.junit.runner.RunWith; |
| import org.junit.runners.JUnit4; |
| |
| /** Tests for GroupByKey. */ |
| @SuppressWarnings({"rawtypes", "unchecked"}) |
| public class GroupByKeyTest implements Serializable { |
| /** Shared test base class with setup/teardown helpers. */ |
| public abstract static class SharedTestBase { |
| @Rule public transient TestPipeline p = TestPipeline.create(); |
| |
| @Rule public transient ExpectedException thrown = ExpectedException.none(); |
| } |
| |
| /** Tests validating basic {@link GroupByKey} scenarios. */ |
| @RunWith(JUnit4.class) |
| public static class BasicTests extends SharedTestBase { |
| @Test |
| @Category(ValidatesRunner.class) |
| public void testGroupByKey() { |
| List<KV<String, Integer>> ungroupedPairs = |
| Arrays.asList( |
| KV.of("k1", 3), |
| KV.of("k5", Integer.MAX_VALUE), |
| KV.of("k5", Integer.MIN_VALUE), |
| KV.of("k2", 66), |
| KV.of("k1", 4), |
| KV.of("k2", -33), |
| KV.of("k3", 0)); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply( |
| Create.of(ungroupedPairs) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); |
| |
| PCollection<KV<String, Iterable<Integer>>> output = input.apply(GroupByKey.create()); |
| |
| SerializableFunction<Iterable<KV<String, Iterable<Integer>>>, Void> checker = |
| containsKvs( |
| kv("k1", 3, 4), |
| kv("k5", Integer.MIN_VALUE, Integer.MAX_VALUE), |
| kv("k2", 66, -33), |
| kv("k3", 0)); |
| PAssert.that(output).satisfies(checker); |
| PAssert.that(output).inWindow(GlobalWindow.INSTANCE).satisfies(checker); |
| |
| p.run(); |
| } |
| |
| @Test |
| @Category(ValidatesRunner.class) |
| public void testGroupByKeyEmpty() { |
| List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply( |
| Create.of(ungroupedPairs) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); |
| |
| PCollection<KV<String, Iterable<Integer>>> output = input.apply(GroupByKey.create()); |
| |
| PAssert.that(output).empty(); |
| |
| p.run(); |
| } |
| |
| /** |
| * Tests that when a processing time timers comes in after a window is expired it does not cause |
| * a spurious output. |
| */ |
| @Test |
| @Category({ValidatesRunner.class, UsesTestStreamWithProcessingTime.class}) |
| public void testCombiningAccumulatingProcessingTime() throws Exception { |
| PCollection<Integer> triggeredSums = |
| p.apply( |
| TestStream.create(VarIntCoder.of()) |
| .advanceWatermarkTo(new Instant(0)) |
| .addElements( |
| TimestampedValue.of(2, new Instant(2)), |
| TimestampedValue.of(5, new Instant(5))) |
| .advanceWatermarkTo(new Instant(100)) |
| .advanceProcessingTime(Duration.millis(10)) |
| .advanceWatermarkToInfinity()) |
| .apply( |
| Window.<Integer>into(FixedWindows.of(Duration.millis(100))) |
| .withTimestampCombiner(TimestampCombiner.EARLIEST) |
| .accumulatingFiredPanes() |
| .withAllowedLateness(Duration.ZERO) |
| .triggering( |
| Repeatedly.forever( |
| AfterProcessingTime.pastFirstElementInPane() |
| .plusDelayOf(Duration.millis(10))))) |
| .apply(Sum.integersGlobally().withoutDefaults()); |
| |
| PAssert.that(triggeredSums).containsInAnyOrder(7); |
| |
| p.run(); |
| } |
| |
| @Test |
| public void testGroupByKeyNonDeterministic() throws Exception { |
| |
| List<KV<Map<String, String>, Integer>> ungroupedPairs = Arrays.asList(); |
| |
| PCollection<KV<Map<String, String>, Integer>> input = |
| p.apply( |
| Create.of(ungroupedPairs) |
| .withCoder( |
| KvCoder.of( |
| MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), |
| BigEndianIntegerCoder.of()))); |
| |
| thrown.expect(IllegalStateException.class); |
| thrown.expectMessage("must be deterministic"); |
| input.apply(GroupByKey.create()); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void testRemerge() { |
| |
| List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply( |
| Create.of(ungroupedPairs) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) |
| .apply(Window.into(Sessions.withGapDuration(Duration.standardMinutes(1)))); |
| |
| PCollection<KV<String, Iterable<Iterable<Integer>>>> middle = |
| input |
| .apply("GroupByKey", GroupByKey.create()) |
| .apply("Remerge", Window.remerge()) |
| .apply("GroupByKeyAgain", GroupByKey.create()) |
| .apply("RemergeAgain", Window.remerge()); |
| |
| p.run(); |
| |
| Assert.assertTrue( |
| middle |
| .getWindowingStrategy() |
| .getWindowFn() |
| .isCompatible(Sessions.withGapDuration(Duration.standardMinutes(1)))); |
| } |
| |
| @Test |
| public void testGroupByKeyDirectUnbounded() { |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply( |
| new PTransform<PBegin, PCollection<KV<String, Integer>>>() { |
| @Override |
| public PCollection<KV<String, Integer>> expand(PBegin input) { |
| return PCollection.createPrimitiveOutputInternal( |
| input.getPipeline(), |
| WindowingStrategy.globalDefault(), |
| PCollection.IsBounded.UNBOUNDED, |
| KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())); |
| } |
| }); |
| |
| thrown.expect(IllegalStateException.class); |
| thrown.expectMessage( |
| "GroupByKey cannot be applied to non-bounded PCollection in the GlobalWindow without " |
| + "a trigger. Use a Window.into or Window.triggering transform prior to GroupByKey."); |
| |
| input.apply("GroupByKey", GroupByKey.create()); |
| } |
| |
| /** |
| * Tests that when two elements are combined via a GroupByKey their output timestamp agrees with |
| * the windowing function customized to actually be the same as the default, the earlier of the |
| * two values. |
| */ |
| @Test |
| @Category(ValidatesRunner.class) |
| public void testTimestampCombinerEarliest() { |
| |
| p.apply( |
| Create.timestamped( |
| TimestampedValue.of(KV.of(0, "hello"), new Instant(0)), |
| TimestampedValue.of(KV.of(0, "goodbye"), new Instant(10)))) |
| .apply( |
| Window.<KV<Integer, String>>into(FixedWindows.of(Duration.standardMinutes(10))) |
| .withTimestampCombiner(TimestampCombiner.EARLIEST)) |
| .apply(GroupByKey.create()) |
| .apply(ParDo.of(new AssertTimestamp(new Instant(0)))); |
| |
| p.run(); |
| } |
| |
| /** |
| * Tests that when two elements are combined via a GroupByKey their output timestamp agrees with |
| * the windowing function customized to use the latest value. |
| */ |
| @Test |
| @Category(ValidatesRunner.class) |
| public void testTimestampCombinerLatest() { |
| p.apply( |
| Create.timestamped( |
| TimestampedValue.of(KV.of(0, "hello"), new Instant(0)), |
| TimestampedValue.of(KV.of(0, "goodbye"), new Instant(10)))) |
| .apply( |
| Window.<KV<Integer, String>>into(FixedWindows.of(Duration.standardMinutes(10))) |
| .withTimestampCombiner(TimestampCombiner.LATEST)) |
| .apply(GroupByKey.create()) |
| .apply(ParDo.of(new AssertTimestamp(new Instant(10)))); |
| |
| p.run(); |
| } |
| |
| @Test |
| public void testGroupByKeyGetName() { |
| Assert.assertEquals("GroupByKey", GroupByKey.<String, Integer>create().getName()); |
| } |
| |
| @Test |
| public void testDisplayData() { |
| GroupByKey<String, String> groupByKey = GroupByKey.create(); |
| GroupByKey<String, String> groupByFewKeys = GroupByKey.createWithFewKeys(); |
| |
| DisplayData gbkDisplayData = DisplayData.from(groupByKey); |
| DisplayData fewKeysDisplayData = DisplayData.from(groupByFewKeys); |
| |
| assertThat(gbkDisplayData.items(), empty()); |
| assertThat(fewKeysDisplayData, hasDisplayItem("fewKeys", true)); |
| } |
| |
| /** Verify that runners correctly hash/group on the encoded value and not the value itself. */ |
| @Test |
| @Category({ValidatesRunner.class, DataflowPortabilityApiUnsupported.class}) |
| public void testGroupByKeyWithBadEqualsHashCode() throws Exception { |
| final int numValues = 10; |
| final int numKeys = 5; |
| |
| p.getCoderRegistry() |
| .registerCoderProvider( |
| CoderProviders.fromStaticMethods(BadEqualityKey.class, DeterministicKeyCoder.class)); |
| |
| // construct input data |
| List<KV<BadEqualityKey, Long>> input = new ArrayList<>(); |
| for (int i = 0; i < numValues; i++) { |
| for (int key = 0; key < numKeys; key++) { |
| input.add(KV.of(new BadEqualityKey(key), 1L)); |
| } |
| } |
| |
| // We first ensure that the values are randomly partitioned in the beginning. |
| // Some runners might otherwise keep all values on the machine where |
| // they are initially created. |
| PCollection<KV<BadEqualityKey, Long>> dataset1 = |
| p.apply(Create.of(input)) |
| .apply(ParDo.of(new AssignRandomKey())) |
| .apply(Reshuffle.of()) |
| .apply(Values.create()); |
| |
| // Make the GroupByKey and Count implicit, in real-world code |
| // this would be a Count.perKey() |
| PCollection<KV<BadEqualityKey, Long>> result = |
| dataset1.apply(GroupByKey.create()).apply(Combine.groupedValues(new CountFn())); |
| |
| PAssert.that(result).satisfies(new AssertThatCountPerKeyCorrect(numValues)); |
| |
| PAssert.that(result.apply(Keys.create())).satisfies(new AssertThatAllKeysExist(numKeys)); |
| |
| p.run(); |
| } |
| |
| @Test |
| @Category({ValidatesRunner.class, LargeKeys.Above10KB.class}) |
| public void testLargeKeys10KB() throws Exception { |
| runLargeKeysTest(p, 10 << 10); |
| } |
| |
| @Test |
| @Category({ValidatesRunner.class, LargeKeys.Above100KB.class}) |
| public void testLargeKeys100KB() throws Exception { |
| runLargeKeysTest(p, 100 << 10); |
| } |
| |
| @Test |
| @Category({ValidatesRunner.class, LargeKeys.Above1MB.class}) |
| public void testLargeKeys1MB() throws Exception { |
| runLargeKeysTest(p, 1 << 20); |
| } |
| |
| @Test |
| @Category({ValidatesRunner.class, LargeKeys.Above10MB.class}) |
| public void testLargeKeys10MB() throws Exception { |
| runLargeKeysTest(p, 10 << 20); |
| } |
| |
| @Test |
| @Category({ValidatesRunner.class, LargeKeys.Above100MB.class}) |
| public void testLargeKeys100MB() throws Exception { |
| runLargeKeysTest(p, 100 << 20); |
| } |
| } |
| |
| /** Tests validating GroupByKey behaviors with windowing. */ |
| @RunWith(JUnit4.class) |
| public static class WindowTests extends SharedTestBase { |
| @Test |
| @Category(ValidatesRunner.class) |
| public void testGroupByKeyAndWindows() { |
| List<KV<String, Integer>> ungroupedPairs = |
| Arrays.asList( |
| KV.of("k1", 3), // window [0, 5) |
| KV.of("k5", Integer.MAX_VALUE), // window [0, 5) |
| KV.of("k5", Integer.MIN_VALUE), // window [0, 5) |
| KV.of("k2", 66), // window [0, 5) |
| KV.of("k1", 4), // window [5, 10) |
| KV.of("k2", -33), // window [5, 10) |
| KV.of("k3", 0)); // window [5, 10) |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply( |
| Create.timestamped(ungroupedPairs, Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L)) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); |
| PCollection<KV<String, Iterable<Integer>>> output = |
| input.apply(Window.into(FixedWindows.of(new Duration(5)))).apply(GroupByKey.create()); |
| |
| PAssert.that(output) |
| .satisfies( |
| containsKvs( |
| kv("k1", 3), |
| kv("k1", 4), |
| kv("k5", Integer.MAX_VALUE, Integer.MIN_VALUE), |
| kv("k2", 66), |
| kv("k2", -33), |
| kv("k3", 0))); |
| PAssert.that(output) |
| .inWindow(new IntervalWindow(new Instant(0L), Duration.millis(5L))) |
| .satisfies( |
| containsKvs( |
| kv("k1", 3), kv("k5", Integer.MIN_VALUE, Integer.MAX_VALUE), kv("k2", 66))); |
| PAssert.that(output) |
| .inWindow(new IntervalWindow(new Instant(5L), Duration.millis(5L))) |
| .satisfies(containsKvs(kv("k1", 4), kv("k2", -33), kv("k3", 0))); |
| |
| p.run(); |
| } |
| |
| @Test |
| @Category(ValidatesRunner.class) |
| public void testGroupByKeyMultipleWindows() { |
| PCollection<KV<String, Integer>> windowedInput = |
| p.apply( |
| Create.timestamped( |
| TimestampedValue.of(KV.of("foo", 1), new Instant(1)), |
| TimestampedValue.of(KV.of("foo", 4), new Instant(4)), |
| TimestampedValue.of(KV.of("bar", 3), new Instant(3)))) |
| .apply( |
| Window.into(SlidingWindows.of(Duration.millis(5L)).every(Duration.millis(3L)))); |
| |
| PCollection<KV<String, Iterable<Integer>>> output = windowedInput.apply(GroupByKey.create()); |
| |
| PAssert.that(output) |
| .satisfies( |
| containsKvs(kv("foo", 1, 4), kv("foo", 1), kv("foo", 4), kv("bar", 3), kv("bar", 3))); |
| PAssert.that(output) |
| .inWindow(new IntervalWindow(new Instant(-3L), Duration.millis(5L))) |
| .satisfies(containsKvs(kv("foo", 1))); |
| PAssert.that(output) |
| .inWindow(new IntervalWindow(new Instant(0L), Duration.millis(5L))) |
| .satisfies(containsKvs(kv("foo", 1, 4), kv("bar", 3))); |
| PAssert.that(output) |
| .inWindow(new IntervalWindow(new Instant(3L), Duration.millis(5L))) |
| .satisfies(containsKvs(kv("foo", 4), kv("bar", 3))); |
| |
| p.run(); |
| } |
| |
| @Test |
| @Category(ValidatesRunner.class) |
| public void testGroupByKeyMergingWindows() { |
| PCollection<KV<String, Integer>> windowedInput = |
| p.apply( |
| Create.timestamped( |
| TimestampedValue.of(KV.of("foo", 1), new Instant(1)), |
| TimestampedValue.of(KV.of("foo", 4), new Instant(4)), |
| TimestampedValue.of(KV.of("bar", 3), new Instant(3)), |
| TimestampedValue.of(KV.of("foo", 9), new Instant(9)))) |
| .apply(Window.into(Sessions.withGapDuration(Duration.millis(4L)))); |
| |
| PCollection<KV<String, Iterable<Integer>>> output = windowedInput.apply(GroupByKey.create()); |
| |
| PAssert.that(output).satisfies(containsKvs(kv("foo", 1, 4), kv("foo", 9), kv("bar", 3))); |
| PAssert.that(output) |
| .inWindow(new IntervalWindow(new Instant(1L), new Instant(8L))) |
| .satisfies(containsKvs(kv("foo", 1, 4))); |
| PAssert.that(output) |
| .inWindow(new IntervalWindow(new Instant(3L), new Instant(7L))) |
| .satisfies(containsKvs(kv("bar", 3))); |
| PAssert.that(output) |
| .inWindow(new IntervalWindow(new Instant(9L), new Instant(13L))) |
| .satisfies(containsKvs(kv("foo", 9))); |
| |
| p.run(); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void testIdentityWindowFnPropagation() { |
| |
| List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply( |
| Create.of(ungroupedPairs) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) |
| .apply(Window.into(FixedWindows.of(Duration.standardMinutes(1)))); |
| |
| PCollection<KV<String, Iterable<Integer>>> output = input.apply(GroupByKey.create()); |
| |
| p.run(); |
| |
| Assert.assertTrue( |
| output |
| .getWindowingStrategy() |
| .getWindowFn() |
| .isCompatible(FixedWindows.of(Duration.standardMinutes(1)))); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void testWindowFnInvalidation() { |
| |
| List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply( |
| Create.of(ungroupedPairs) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) |
| .apply(Window.into(Sessions.withGapDuration(Duration.standardMinutes(1)))); |
| |
| PCollection<KV<String, Iterable<Integer>>> output = input.apply(GroupByKey.create()); |
| |
| p.run(); |
| |
| Assert.assertTrue( |
| output |
| .getWindowingStrategy() |
| .getWindowFn() |
| .isCompatible( |
| new InvalidWindows( |
| "Invalid", Sessions.withGapDuration(Duration.standardMinutes(1))))); |
| } |
| |
| @Test |
| public void testInvalidWindowsDirect() { |
| |
| List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply( |
| Create.of(ungroupedPairs) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) |
| .apply(Window.into(Sessions.withGapDuration(Duration.standardMinutes(1)))); |
| |
| thrown.expect(IllegalStateException.class); |
| thrown.expectMessage("GroupByKey must have a valid Window merge function"); |
| input.apply("GroupByKey", GroupByKey.create()).apply("GroupByKeyAgain", GroupByKey.create()); |
| } |
| } |
| |
| private static KV<String, Collection<Integer>> kv(String key, Integer... values) { |
| return KV.of(key, ImmutableList.copyOf(values)); |
| } |
| |
| private static SerializableFunction<Iterable<KV<String, Iterable<Integer>>>, Void> containsKvs( |
| KV<String, Collection<Integer>>... kvs) { |
| return new ContainsKVs(ImmutableList.copyOf(kvs)); |
| } |
| |
| /** |
| * A function that asserts that the input element contains the expected {@link KV KVs} in any |
| * order, where values appear in any order. |
| */ |
| private static class ContainsKVs |
| implements SerializableFunction<Iterable<KV<String, Iterable<Integer>>>, Void> { |
| private final List<KV<String, Collection<Integer>>> expectedKvs; |
| |
| private ContainsKVs(List<KV<String, Collection<Integer>>> expectedKvs) { |
| this.expectedKvs = expectedKvs; |
| } |
| |
| @Override |
| public Void apply(Iterable<KV<String, Iterable<Integer>>> input) { |
| List<Matcher<? super KV<String, Iterable<Integer>>>> matchers = new ArrayList<>(); |
| for (KV<String, Collection<Integer>> expected : expectedKvs) { |
| Integer[] values = expected.getValue().toArray(new Integer[0]); |
| matchers.add(isKv(equalTo(expected.getKey()), containsInAnyOrder(values))); |
| } |
| assertThat(input, containsInAnyOrder(matchers.toArray(new Matcher[0]))); |
| return null; |
| } |
| } |
| |
| private static class AssertTimestamp<K, V> extends DoFn<KV<K, V>, Void> { |
| private final Instant timestamp; |
| |
| public AssertTimestamp(Instant timestamp) { |
| this.timestamp = timestamp; |
| } |
| |
| @ProcessElement |
| public void processElement(ProcessContext c) throws Exception { |
| assertThat(c.timestamp(), equalTo(timestamp)); |
| } |
| } |
| |
| private static String bigString(char c, int size) { |
| char[] buf = new char[size]; |
| for (int i = 0; i < size; i++) { |
| buf[i] = c; |
| } |
| return new String(buf); |
| } |
| |
| private static void runLargeKeysTest(TestPipeline p, final int keySize) throws Exception { |
| PCollection<KV<String, Integer>> result = |
| p.apply(Create.of("a", "a", "b")) |
| .apply( |
| "Expand", |
| ParDo.of( |
| new DoFn<String, KV<String, String>>() { |
| @ProcessElement |
| public void process(ProcessContext c) { |
| c.output(KV.of(bigString(c.element().charAt(0), keySize), c.element())); |
| } |
| })) |
| .apply(GroupByKey.create()) |
| .apply( |
| "Count", |
| ParDo.of( |
| new DoFn<KV<String, Iterable<String>>, KV<String, Integer>>() { |
| @ProcessElement |
| public void process(ProcessContext c) { |
| int size = 0; |
| for (String value : c.element().getValue()) { |
| size++; |
| } |
| c.output(KV.of(c.element().getKey(), size)); |
| } |
| })); |
| |
| PAssert.that(result) |
| .satisfies( |
| values -> { |
| assertThat( |
| values, |
| containsInAnyOrder( |
| KV.of(bigString('a', keySize), 2), KV.of(bigString('b', keySize), 1))); |
| return null; |
| }); |
| |
| p.run(); |
| } |
| |
| /** |
| * This is a bogus key class that returns random hash values from {@link #hashCode()} and always |
| * returns {@code false} for {@link #equals(Object)}. The results of the test are correct if the |
| * runner correctly hashes and sorts on the encoded bytes. |
| */ |
| static class BadEqualityKey { |
| long key; |
| |
| public BadEqualityKey() {} |
| |
| public BadEqualityKey(long key) { |
| this.key = key; |
| } |
| |
| @Override |
| public boolean equals(Object o) { |
| return false; |
| } |
| |
| @Override |
| public int hashCode() { |
| return ThreadLocalRandom.current().nextInt(); |
| } |
| } |
| |
| /** Deterministic {@link Coder} for {@link BadEqualityKey}. */ |
| static class DeterministicKeyCoder extends AtomicCoder<BadEqualityKey> { |
| |
| public static DeterministicKeyCoder of() { |
| return INSTANCE; |
| } |
| |
| ///////////////////////////////////////////////////////////////////////////// |
| |
| private static final DeterministicKeyCoder INSTANCE = new DeterministicKeyCoder(); |
| |
| private DeterministicKeyCoder() {} |
| |
| @Override |
| public void encode(BadEqualityKey value, OutputStream outStream) throws IOException { |
| new DataOutputStream(outStream).writeLong(value.key); |
| } |
| |
| @Override |
| public BadEqualityKey decode(InputStream inStream) throws IOException { |
| return new BadEqualityKey(new DataInputStream(inStream).readLong()); |
| } |
| |
| @Override |
| public void verifyDeterministic() {} |
| } |
| |
| /** Creates a KV that wraps the original KV together with a random key. */ |
| static class AssignRandomKey |
| extends DoFn<KV<BadEqualityKey, Long>, KV<Long, KV<BadEqualityKey, Long>>> { |
| |
| @ProcessElement |
| public void processElement(ProcessContext c) throws Exception { |
| c.output(KV.of(ThreadLocalRandom.current().nextLong(), c.element())); |
| } |
| } |
| |
| static class CountFn implements SerializableFunction<Iterable<Long>, Long> { |
| @Override |
| public Long apply(Iterable<Long> input) { |
| long result = 0L; |
| for (Long in : input) { |
| result += in; |
| } |
| return result; |
| } |
| } |
| |
| static class AssertThatCountPerKeyCorrect |
| implements SerializableFunction<Iterable<KV<BadEqualityKey, Long>>, Void> { |
| private final int numValues; |
| |
| AssertThatCountPerKeyCorrect(int numValues) { |
| this.numValues = numValues; |
| } |
| |
| @Override |
| public Void apply(Iterable<KV<BadEqualityKey, Long>> input) { |
| for (KV<BadEqualityKey, Long> val : input) { |
| Assert.assertEquals(numValues, (long) val.getValue()); |
| } |
| return null; |
| } |
| } |
| |
| static class AssertThatAllKeysExist |
| implements SerializableFunction<Iterable<BadEqualityKey>, Void> { |
| private final int numKeys; |
| |
| AssertThatAllKeysExist(int numKeys) { |
| this.numKeys = numKeys; |
| } |
| |
| private static <T> Iterable<Object> asStructural( |
| final Iterable<T> iterable, final Coder<T> coder) { |
| |
| return StreamSupport.stream(iterable.spliterator(), false) |
| .map( |
| input -> { |
| try { |
| return coder.structuralValue(input); |
| } catch (Exception e) { |
| Assert.fail("Could not structural values."); |
| throw new RuntimeException(); // to satisfy the compiler... |
| } |
| }) |
| .collect(Collectors.toList()); |
| } |
| |
| @Override |
| public Void apply(Iterable<BadEqualityKey> input) { |
| final DeterministicKeyCoder keyCoder = DeterministicKeyCoder.of(); |
| |
| List<BadEqualityKey> expectedList = new ArrayList<>(); |
| for (int key = 0; key < numKeys; key++) { |
| expectedList.add(new BadEqualityKey(key)); |
| } |
| |
| Iterable<Object> structuralInput = asStructural(input, keyCoder); |
| Iterable<Object> structuralExpected = asStructural(expectedList, keyCoder); |
| |
| for (Object expected : structuralExpected) { |
| assertThat(structuralInput, hasItem(expected)); |
| } |
| |
| return null; |
| } |
| } |
| } |