| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.sdk.transforms; |
| |
| import static org.apache.beam.sdk.TestUtils.KvMatcher.isKv; |
| import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; |
| import static org.hamcrest.CoreMatchers.equalTo; |
| import static org.hamcrest.CoreMatchers.hasItem; |
| import static org.hamcrest.Matchers.empty; |
| import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; |
| import static org.hamcrest.core.Is.is; |
| import static org.junit.Assert.assertThat; |
| |
| import org.apache.beam.sdk.Pipeline; |
| import org.apache.beam.sdk.coders.AtomicCoder; |
| import org.apache.beam.sdk.coders.BigEndianIntegerCoder; |
| import org.apache.beam.sdk.coders.Coder; |
| import org.apache.beam.sdk.coders.KvCoder; |
| import org.apache.beam.sdk.coders.MapCoder; |
| import org.apache.beam.sdk.coders.StringUtf8Coder; |
| import org.apache.beam.sdk.testing.NeedsRunner; |
| import org.apache.beam.sdk.testing.PAssert; |
| import org.apache.beam.sdk.testing.RunnableOnService; |
| import org.apache.beam.sdk.testing.TestPipeline; |
| import org.apache.beam.sdk.transforms.display.DisplayData; |
| import org.apache.beam.sdk.transforms.windowing.FixedWindows; |
| import org.apache.beam.sdk.transforms.windowing.InvalidWindows; |
| import org.apache.beam.sdk.transforms.windowing.OutputTimeFns; |
| import org.apache.beam.sdk.transforms.windowing.Sessions; |
| import org.apache.beam.sdk.transforms.windowing.Window; |
| import org.apache.beam.sdk.util.Reshuffle; |
| import org.apache.beam.sdk.util.WindowingStrategy; |
| import org.apache.beam.sdk.values.KV; |
| import org.apache.beam.sdk.values.PBegin; |
| import org.apache.beam.sdk.values.PCollection; |
| import org.apache.beam.sdk.values.TimestampedValue; |
| import org.apache.beam.sdk.values.TypeDescriptor; |
| |
| import com.google.common.base.Function; |
| import com.google.common.collect.Iterables; |
| |
| import com.fasterxml.jackson.annotation.JsonCreator; |
| |
| import org.joda.time.Duration; |
| import org.joda.time.Instant; |
| import org.junit.Assert; |
| import org.junit.Rule; |
| import org.junit.Test; |
| import org.junit.experimental.categories.Category; |
| import org.junit.rules.ExpectedException; |
| import org.junit.runner.RunWith; |
| import org.junit.runners.JUnit4; |
| |
| import java.io.DataInputStream; |
| import java.io.DataOutputStream; |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.io.OutputStream; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.concurrent.ThreadLocalRandom; |
| |
| /** |
| * Tests for GroupByKey. |
| */ |
| @RunWith(JUnit4.class) |
| @SuppressWarnings({"rawtypes", "unchecked"}) |
| public class GroupByKeyTest { |
| |
| @Rule |
| public ExpectedException thrown = ExpectedException.none(); |
| |
| @Test |
| @Category(RunnableOnService.class) |
| public void testGroupByKey() { |
| List<KV<String, Integer>> ungroupedPairs = Arrays.asList( |
| KV.of("k1", 3), |
| KV.of("k5", Integer.MAX_VALUE), |
| KV.of("k5", Integer.MIN_VALUE), |
| KV.of("k2", 66), |
| KV.of("k1", 4), |
| KV.of("k2", -33), |
| KV.of("k3", 0)); |
| |
| Pipeline p = TestPipeline.create(); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply(Create.of(ungroupedPairs) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); |
| |
| PCollection<KV<String, Iterable<Integer>>> output = |
| input.apply(GroupByKey.<String, Integer>create()); |
| |
| PAssert.that(output) |
| .satisfies(new AssertThatHasExpectedContentsForTestGroupByKey()); |
| |
| p.run(); |
| } |
| |
| static class AssertThatHasExpectedContentsForTestGroupByKey |
| implements SerializableFunction<Iterable<KV<String, Iterable<Integer>>>, |
| Void> { |
| @Override |
| public Void apply(Iterable<KV<String, Iterable<Integer>>> actual) { |
| assertThat(actual, containsInAnyOrder( |
| isKv(is("k1"), containsInAnyOrder(3, 4)), |
| isKv(is("k5"), containsInAnyOrder(Integer.MAX_VALUE, |
| Integer.MIN_VALUE)), |
| isKv(is("k2"), containsInAnyOrder(66, -33)), |
| isKv(is("k3"), containsInAnyOrder(0)))); |
| return null; |
| } |
| } |
| |
| @Test |
| @Category(RunnableOnService.class) |
| public void testGroupByKeyAndWindows() { |
| List<KV<String, Integer>> ungroupedPairs = Arrays.asList( |
| KV.of("k1", 3), // window [0, 5) |
| KV.of("k5", Integer.MAX_VALUE), // window [0, 5) |
| KV.of("k5", Integer.MIN_VALUE), // window [0, 5) |
| KV.of("k2", 66), // window [0, 5) |
| KV.of("k1", 4), // window [5, 10) |
| KV.of("k2", -33), // window [5, 10) |
| KV.of("k3", 0)); // window [5, 10) |
| |
| Pipeline p = TestPipeline.create(); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply(Create.timestamped(ungroupedPairs, Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L)) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); |
| PCollection<KV<String, Iterable<Integer>>> output = |
| input.apply(Window.<KV<String, Integer>>into(FixedWindows.of(new Duration(5)))) |
| .apply(GroupByKey.<String, Integer>create()); |
| |
| PAssert.that(output) |
| .satisfies(new AssertThatHasExpectedContentsForTestGroupByKeyAndWindows()); |
| |
| p.run(); |
| } |
| |
| static class AssertThatHasExpectedContentsForTestGroupByKeyAndWindows |
| implements SerializableFunction<Iterable<KV<String, Iterable<Integer>>>, |
| Void> { |
| @Override |
| public Void apply(Iterable<KV<String, Iterable<Integer>>> actual) { |
| assertThat(actual, containsInAnyOrder( |
| isKv(is("k1"), containsInAnyOrder(3)), |
| isKv(is("k1"), containsInAnyOrder(4)), |
| isKv(is("k5"), containsInAnyOrder(Integer.MAX_VALUE, |
| Integer.MIN_VALUE)), |
| isKv(is("k2"), containsInAnyOrder(66)), |
| isKv(is("k2"), containsInAnyOrder(-33)), |
| isKv(is("k3"), containsInAnyOrder(0)))); |
| return null; |
| } |
| } |
| |
| @Test |
| @Category(RunnableOnService.class) |
| public void testGroupByKeyEmpty() { |
| List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); |
| |
| Pipeline p = TestPipeline.create(); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply(Create.of(ungroupedPairs) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); |
| |
| PCollection<KV<String, Iterable<Integer>>> output = |
| input.apply(GroupByKey.<String, Integer>create()); |
| |
| PAssert.that(output).empty(); |
| |
| p.run(); |
| } |
| |
| @Test |
| public void testGroupByKeyNonDeterministic() throws Exception { |
| |
| List<KV<Map<String, String>, Integer>> ungroupedPairs = Arrays.asList(); |
| |
| Pipeline p = TestPipeline.create(); |
| |
| PCollection<KV<Map<String, String>, Integer>> input = |
| p.apply(Create.of(ungroupedPairs) |
| .withCoder( |
| KvCoder.of(MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), |
| BigEndianIntegerCoder.of()))); |
| |
| thrown.expect(IllegalStateException.class); |
| thrown.expectMessage("must be deterministic"); |
| input.apply(GroupByKey.<Map<String, String>, Integer>create()); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void testIdentityWindowFnPropagation() { |
| Pipeline p = TestPipeline.create(); |
| |
| List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply(Create.of(ungroupedPairs) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) |
| .apply(Window.<KV<String, Integer>>into(FixedWindows.of(Duration.standardMinutes(1)))); |
| |
| PCollection<KV<String, Iterable<Integer>>> output = |
| input.apply(GroupByKey.<String, Integer>create()); |
| |
| p.run(); |
| |
| Assert.assertTrue(output.getWindowingStrategy().getWindowFn().isCompatible( |
| FixedWindows.of(Duration.standardMinutes(1)))); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void testWindowFnInvalidation() { |
| Pipeline p = TestPipeline.create(); |
| |
| List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply(Create.of(ungroupedPairs) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) |
| .apply(Window.<KV<String, Integer>>into( |
| Sessions.withGapDuration(Duration.standardMinutes(1)))); |
| |
| PCollection<KV<String, Iterable<Integer>>> output = |
| input.apply(GroupByKey.<String, Integer>create()); |
| |
| p.run(); |
| |
| Assert.assertTrue( |
| output.getWindowingStrategy().getWindowFn().isCompatible( |
| new InvalidWindows( |
| "Invalid", |
| Sessions.withGapDuration( |
| Duration.standardMinutes(1))))); |
| } |
| |
| @Test |
| public void testInvalidWindowsDirect() { |
| Pipeline p = TestPipeline.create(); |
| |
| List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply(Create.of(ungroupedPairs) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) |
| .apply(Window.<KV<String, Integer>>into( |
| Sessions.withGapDuration(Duration.standardMinutes(1)))); |
| |
| thrown.expect(IllegalStateException.class); |
| thrown.expectMessage("GroupByKey must have a valid Window merge function"); |
| input |
| .apply("GroupByKey", GroupByKey.<String, Integer>create()) |
| .apply("GroupByKeyAgain", GroupByKey.<String, Iterable<Integer>>create()); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void testRemerge() { |
| Pipeline p = TestPipeline.create(); |
| |
| List<KV<String, Integer>> ungroupedPairs = Arrays.asList(); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply(Create.of(ungroupedPairs) |
| .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) |
| .apply(Window.<KV<String, Integer>>into( |
| Sessions.withGapDuration(Duration.standardMinutes(1)))); |
| |
| PCollection<KV<String, Iterable<Iterable<Integer>>>> middle = input |
| .apply("GroupByKey", GroupByKey.<String, Integer>create()) |
| .apply("Remerge", Window.<KV<String, Iterable<Integer>>>remerge()) |
| .apply("GroupByKeyAgain", GroupByKey.<String, Iterable<Integer>>create()) |
| .apply("RemergeAgain", Window.<KV<String, Iterable<Iterable<Integer>>>>remerge()); |
| |
| p.run(); |
| |
| Assert.assertTrue( |
| middle.getWindowingStrategy().getWindowFn().isCompatible( |
| Sessions.withGapDuration(Duration.standardMinutes(1)))); |
| } |
| |
| @Test |
| public void testGroupByKeyDirectUnbounded() { |
| Pipeline p = TestPipeline.create(); |
| |
| PCollection<KV<String, Integer>> input = |
| p.apply( |
| new PTransform<PBegin, PCollection<KV<String, Integer>>>() { |
| @Override |
| public PCollection<KV<String, Integer>> apply(PBegin input) { |
| return PCollection.<KV<String, Integer>>createPrimitiveOutputInternal( |
| input.getPipeline(), |
| WindowingStrategy.globalDefault(), |
| PCollection.IsBounded.UNBOUNDED) |
| .setTypeDescriptorInternal(new TypeDescriptor<KV<String, Integer>>() {}); |
| } |
| }); |
| |
| thrown.expect(IllegalStateException.class); |
| thrown.expectMessage( |
| "GroupByKey cannot be applied to non-bounded PCollection in the GlobalWindow without " |
| + "a trigger. Use a Window.into or Window.triggering transform prior to GroupByKey."); |
| |
| input.apply("GroupByKey", GroupByKey.<String, Integer>create()); |
| } |
| |
| /** |
| * Tests that when two elements are combined via a GroupByKey their output timestamp agrees |
| * with the windowing function customized to actually be the same as the default, the earlier of |
| * the two values. |
| */ |
| @Test |
| @Category(RunnableOnService.class) |
| public void testOutputTimeFnEarliest() { |
| Pipeline pipeline = TestPipeline.create(); |
| |
| pipeline.apply( |
| Create.timestamped( |
| TimestampedValue.of(KV.of(0, "hello"), new Instant(0)), |
| TimestampedValue.of(KV.of(0, "goodbye"), new Instant(10)))) |
| .apply(Window.<KV<Integer, String>>into(FixedWindows.of(Duration.standardMinutes(10))) |
| .withOutputTimeFn(OutputTimeFns.outputAtEarliestInputTimestamp())) |
| .apply(GroupByKey.<Integer, String>create()) |
| .apply(ParDo.of(new AssertTimestamp(new Instant(0)))); |
| |
| pipeline.run(); |
| } |
| |
| |
| /** |
| * Tests that when two elements are combined via a GroupByKey their output timestamp agrees |
| * with the windowing function customized to use the latest value. |
| */ |
| @Test |
| @Category(RunnableOnService.class) |
| public void testOutputTimeFnLatest() { |
| Pipeline pipeline = TestPipeline.create(); |
| |
| pipeline.apply( |
| Create.timestamped( |
| TimestampedValue.of(KV.of(0, "hello"), new Instant(0)), |
| TimestampedValue.of(KV.of(0, "goodbye"), new Instant(10)))) |
| .apply(Window.<KV<Integer, String>>into(FixedWindows.of(Duration.standardMinutes(10))) |
| .withOutputTimeFn(OutputTimeFns.outputAtLatestInputTimestamp())) |
| .apply(GroupByKey.<Integer, String>create()) |
| .apply(ParDo.of(new AssertTimestamp(new Instant(10)))); |
| |
| pipeline.run(); |
| } |
| |
| private static class AssertTimestamp<K, V> extends DoFn<KV<K, V>, Void> { |
| private final Instant timestamp; |
| |
| public AssertTimestamp(Instant timestamp) { |
| this.timestamp = timestamp; |
| } |
| |
| @Override |
| public void processElement(ProcessContext c) throws Exception { |
| assertThat(c.timestamp(), equalTo(timestamp)); |
| } |
| } |
| |
| @Test |
| public void testGroupByKeyGetName() { |
| Assert.assertEquals("GroupByKey", GroupByKey.<String, Integer>create().getName()); |
| } |
| |
| @Test |
| public void testDisplayData() { |
| GroupByKey<String, String> groupByKey = GroupByKey.create(); |
| GroupByKey<String, String> groupByFewKeys = GroupByKey.create(true); |
| |
| DisplayData gbkDisplayData = DisplayData.from(groupByKey); |
| DisplayData fewKeysDisplayData = DisplayData.from(groupByFewKeys); |
| |
| assertThat(gbkDisplayData.items(), empty()); |
| assertThat(fewKeysDisplayData, hasDisplayItem("fewKeys", true)); |
| } |
| |
| |
| /** |
| * Verify that runners correctly hash/group on the encoded value |
| * and not the value itself. |
| */ |
| @Test |
| @Category(RunnableOnService.class) |
| public void testGroupByKeyWithBadEqualsHashCode() throws Exception { |
| final int numValues = 10; |
| final int numKeys = 5; |
| |
| Pipeline p = TestPipeline.create(); |
| |
| p.getCoderRegistry().registerCoder(BadEqualityKey.class, DeterministicKeyCoder.class); |
| |
| // construct input data |
| List<KV<BadEqualityKey, Long>> input = new ArrayList<>(); |
| for (int i = 0; i < numValues; i++) { |
| for (int key = 0; key < numKeys; key++) { |
| input.add(KV.of(new BadEqualityKey(key), 1L)); |
| } |
| } |
| |
| // We first ensure that the values are randomly partitioned in the beginning. |
| // Some runners might otherwise keep all values on the machine where |
| // they are initially created. |
| PCollection<KV<BadEqualityKey, Long>> dataset1 = p |
| .apply(Create.of(input)) |
| .apply(ParDo.of(new AssignRandomKey())) |
| .apply(Reshuffle.<Long, KV<BadEqualityKey, Long>>of()) |
| .apply(Values.<KV<BadEqualityKey, Long>>create()); |
| |
| // Make the GroupByKey and Count implicit, in real-world code |
| // this would be a Count.perKey() |
| PCollection<KV<BadEqualityKey, Long>> result = dataset1 |
| .apply(GroupByKey.<BadEqualityKey, Long>create()) |
| .apply(Combine.<BadEqualityKey, Long>groupedValues(new CountFn())); |
| |
| PAssert.that(result).satisfies(new AssertThatCountPerKeyCorrect(numValues)); |
| |
| PAssert.that(result.apply(Keys.<BadEqualityKey>create())) |
| .satisfies(new AssertThatAllKeysExist(numKeys)); |
| |
| p.run(); |
| } |
| |
| /** |
| * This is a bogus key class that returns random hash values from {@link #hashCode()} and always |
| * returns {@code false} for {@link #equals(Object)}. The results of the test are correct if |
| * the runner correctly hashes and sorts on the encoded bytes. |
| */ |
| static class BadEqualityKey { |
| long key; |
| |
| public BadEqualityKey() {} |
| |
| public BadEqualityKey(long key) { |
| this.key = key; |
| } |
| |
| @Override |
| public boolean equals(Object o) { |
| return false; |
| } |
| |
| @Override |
| public int hashCode() { |
| return ThreadLocalRandom.current().nextInt(); |
| } |
| } |
| |
| /** |
| * Deterministic {@link Coder} for {@link BadEqualityKey}. |
| */ |
| static class DeterministicKeyCoder extends AtomicCoder<BadEqualityKey> { |
| |
| @JsonCreator |
| public static DeterministicKeyCoder of() { |
| return INSTANCE; |
| } |
| |
| ///////////////////////////////////////////////////////////////////////////// |
| |
| private static final DeterministicKeyCoder INSTANCE = |
| new DeterministicKeyCoder(); |
| |
| private DeterministicKeyCoder() {} |
| |
| @Override |
| public void encode(BadEqualityKey value, OutputStream outStream, Context context) |
| throws IOException { |
| new DataOutputStream(outStream).writeLong(value.key); |
| } |
| |
| @Override |
| public BadEqualityKey decode(InputStream inStream, Context context) |
| throws IOException { |
| return new BadEqualityKey(new DataInputStream(inStream).readLong()); |
| } |
| } |
| |
| /** |
| * Creates a KV that wraps the original KV together with a random key. |
| */ |
| static class AssignRandomKey |
| extends DoFn<KV<BadEqualityKey, Long>, KV<Long, KV<BadEqualityKey, Long>>> { |
| |
| @Override |
| public void processElement(ProcessContext c) throws Exception { |
| c.output(KV.of(ThreadLocalRandom.current().nextLong(), c.element())); |
| } |
| } |
| |
| static class CountFn implements SerializableFunction<Iterable<Long>, Long> { |
| @Override |
| public Long apply(Iterable<Long> input) { |
| long result = 0L; |
| for (Long in: input) { |
| result += in; |
| } |
| return result; |
| } |
| } |
| |
| static class AssertThatCountPerKeyCorrect |
| implements SerializableFunction<Iterable<KV<BadEqualityKey, Long>>, Void> { |
| private final int numValues; |
| |
| AssertThatCountPerKeyCorrect(int numValues) { |
| this.numValues = numValues; |
| } |
| |
| @Override |
| public Void apply(Iterable<KV<BadEqualityKey, Long>> input) { |
| for (KV<BadEqualityKey, Long> val: input) { |
| Assert.assertEquals(numValues, (long) val.getValue()); |
| } |
| return null; |
| } |
| } |
| |
| static class AssertThatAllKeysExist |
| implements SerializableFunction<Iterable<BadEqualityKey>, Void> { |
| private final int numKeys; |
| |
| AssertThatAllKeysExist(int numKeys) { |
| this.numKeys = numKeys; |
| } |
| |
| private static <T> Iterable<Object> asStructural( |
| final Iterable<T> iterable, |
| final Coder<T> coder) { |
| |
| return Iterables.transform( |
| iterable, |
| new Function<T, Object>() { |
| @Override |
| public Object apply(T input) { |
| try { |
| return coder.structuralValue(input); |
| } catch (Exception e) { |
| Assert.fail("Could not structural values."); |
| throw new RuntimeException(); // to satisfy the compiler... |
| } |
| } |
| }); |
| |
| } |
| @Override |
| public Void apply(Iterable<BadEqualityKey> input) { |
| final DeterministicKeyCoder keyCoder = DeterministicKeyCoder.of(); |
| |
| List<BadEqualityKey> expectedList = new ArrayList<>(); |
| for (int key = 0; key < numKeys; key++) { |
| expectedList.add(new BadEqualityKey(key)); |
| } |
| |
| Iterable<Object> structuralInput = asStructural(input, keyCoder); |
| Iterable<Object> structuralExpected = asStructural(expectedList, keyCoder); |
| |
| for (Object expected: structuralExpected) { |
| assertThat(structuralInput, hasItem(expected)); |
| } |
| |
| return null; |
| } |
| } |
| } |