// blob: 818716c45d43d62cd537c9b207d50e767e7c8635 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.direct;
import static org.hamcrest.Matchers.contains;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import org.apache.beam.runners.core.KeyedWorkItem;
import org.apache.beam.runners.core.KeyedWorkItems;
import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly;
import org.apache.beam.runners.local.StructuralKey;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.HashMultiset;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Multiset;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.joda.time.Instant;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link GroupByKeyOnlyEvaluatorFactory}. */
@RunWith(JUnit4.class)
public class GroupByKeyEvaluatorFactoryTest {
  private final BundleFactory bundleFactory = ImmutableListBundleFactory.create();

  @Rule public TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false);

  /**
   * Pushes six KVs spread over three keys ("foo", "bar", "baz") through an evaluator created by
   * {@link GroupByKeyOnlyEvaluatorFactory} and verifies that each key's output bundle holds
   * exactly one {@link KeyedWorkItem} carrying all of that key's values.
   */
  @Test
  public void testInMemoryEvaluator() throws Exception {
    KV<String, Integer> firstFoo = KV.of("foo", -1);
    KV<String, Integer> secondFoo = KV.of("foo", 1);
    KV<String, Integer> thirdFoo = KV.of("foo", 3);
    KV<String, Integer> firstBar = KV.of("bar", 22);
    KV<String, Integer> secondBar = KV.of("bar", 12);
    KV<String, Integer> firstBaz = KV.of("baz", Integer.MAX_VALUE);

    PCollection<KV<String, Integer>> values =
        p.apply(Create.of(firstFoo, firstBar, secondFoo, firstBaz, secondBar, thirdFoo));
    PCollection<KeyedWorkItem<String, Integer>> groupedKvs =
        values.apply(new DirectGroupByKeyOnly<>());

    CommittedBundle<KV<String, Integer>> inputBundle =
        bundleFactory.createBundle(values).commit(Instant.now());
    EvaluationContext evaluationContext = mock(EvaluationContext.class);

    // One real keyed bundle per key; the mocked context hands these back so the test can
    // inspect what the evaluator wrote into them.
    StructuralKey<String> fooKey = StructuralKey.of("foo", StringUtf8Coder.of());
    UncommittedBundle<KeyedWorkItem<String, Integer>> fooBundle =
        bundleFactory.createKeyedBundle(fooKey, groupedKvs);
    StructuralKey<String> barKey = StructuralKey.of("bar", StringUtf8Coder.of());
    UncommittedBundle<KeyedWorkItem<String, Integer>> barBundle =
        bundleFactory.createKeyedBundle(barKey, groupedKvs);
    StructuralKey<String> bazKey = StructuralKey.of("baz", StringUtf8Coder.of());
    UncommittedBundle<KeyedWorkItem<String, Integer>> bazBundle =
        bundleFactory.createKeyedBundle(bazKey, groupedKvs);

    when(evaluationContext.createKeyedBundle(fooKey, groupedKvs)).thenReturn(fooBundle);
    when(evaluationContext.createKeyedBundle(barKey, groupedKvs)).thenReturn(barBundle);
    when(evaluationContext.createKeyedBundle(bazKey, groupedKvs)).thenReturn(bazBundle);

    // The input to a GroupByKey is assumed to be a KvCoder
    @SuppressWarnings("unchecked")
    Coder<String> keyCoder = ((KvCoder<String, Integer>) values.getCoder()).getKeyCoder();

    TransformEvaluator<KV<String, Integer>> evaluator =
        new GroupByKeyOnlyEvaluatorFactory(evaluationContext)
            .forApplication(DirectGraphs.getProducer(groupedKvs), inputBundle);

    evaluator.processElement(WindowedValue.valueInGlobalWindow(firstFoo));
    evaluator.processElement(WindowedValue.valueInGlobalWindow(secondFoo));
    evaluator.processElement(WindowedValue.valueInGlobalWindow(thirdFoo));
    evaluator.processElement(WindowedValue.valueInGlobalWindow(firstBar));
    evaluator.processElement(WindowedValue.valueInGlobalWindow(secondBar));
    evaluator.processElement(WindowedValue.valueInGlobalWindow(firstBaz));

    evaluator.finishBundle();

    assertThat(
        fooBundle.commit(Instant.now()).getElements(),
        contains(
            new KeyedWorkItemMatcher<>(
                KeyedWorkItems.elementsWorkItem(
                    "foo",
                    ImmutableSet.of(
                        WindowedValue.valueInGlobalWindow(-1),
                        WindowedValue.valueInGlobalWindow(1),
                        WindowedValue.valueInGlobalWindow(3))),
                keyCoder)));
    assertThat(
        barBundle.commit(Instant.now()).getElements(),
        contains(
            new KeyedWorkItemMatcher<>(
                KeyedWorkItems.elementsWorkItem(
                    "bar",
                    ImmutableSet.of(
                        WindowedValue.valueInGlobalWindow(12),
                        WindowedValue.valueInGlobalWindow(22))),
                keyCoder)));
    assertThat(
        bazBundle.commit(Instant.now()).getElements(),
        contains(
            new KeyedWorkItemMatcher<>(
                KeyedWorkItems.elementsWorkItem(
                    "baz", ImmutableSet.of(WindowedValue.valueInGlobalWindow(Integer.MAX_VALUE))),
                keyCoder)));
  }

  /**
   * Matches a {@link WindowedValue} wrapping a {@link KeyedWorkItem} whose key is structurally
   * equal (per the supplied key {@link Coder}) to the expected work item's key and whose elements,
   * compared as a multiset (order ignored, duplicates counted), equal the expected elements.
   *
   * <p>Only the key and the elements are compared; nothing else on the work item or the enclosing
   * {@link WindowedValue} is inspected.
   */
  private static class KeyedWorkItemMatcher<K, V>
      extends BaseMatcher<WindowedValue<KeyedWorkItem<K, V>>> {
    private final KeyedWorkItem<K, V> myWorkItem;
    private final Coder<K> keyCoder;

    /**
     * @param myWorkItem the expected work item
     * @param keyCoder coder used to derive structural values for key comparison
     */
    KeyedWorkItemMatcher(KeyedWorkItem<K, V> myWorkItem, Coder<K> keyCoder) {
      this.myWorkItem = myWorkItem;
      this.keyCoder = keyCoder;
    }

    /**
     * Returns {@code true} only if {@code item} is a {@link WindowedValue} holding a {@link
     * KeyedWorkItem} with a structurally equal key and the same multiset of elements. Any other
     * input, including {@code null} or a {@link WindowedValue} with a different payload type,
     * yields {@code false} rather than an exception.
     */
    @Override
    public boolean matches(Object item) {
      // instanceof is false for null, so no separate null check is required.
      if (!(item instanceof WindowedValue)) {
        return false;
      }
      // Check the payload before casting: a matcher must reject foreign values, not throw.
      Object payload = ((WindowedValue<?>) item).getValue();
      if (!(payload instanceof KeyedWorkItem)) {
        return false;
      }
      // Type arguments are erased at runtime; the raw type has been verified above.
      @SuppressWarnings("unchecked")
      KeyedWorkItem<K, V> thatWorkItem = (KeyedWorkItem<K, V>) payload;

      Multiset<WindowedValue<V>> myValues = HashMultiset.create(myWorkItem.elementsIterable());
      Multiset<WindowedValue<V>> thatValues = HashMultiset.create(thatWorkItem.elementsIterable());
      try {
        return myValues.equals(thatValues)
            && keyCoder
                .structuralValue(myWorkItem.key())
                .equals(keyCoder.structuralValue(thatWorkItem.key()));
      } catch (Exception e) {
        // A key the coder cannot turn into a structural value is treated as a mismatch.
        return false;
      }
    }

    @Override
    public void describeTo(Description description) {
      description
          .appendText("KeyedWorkItem<K, V> containing key ")
          .appendValue(myWorkItem.key())
          .appendText(" and values ")
          .appendValueList("[", ", ", "]", myWorkItem.elementsIterable());
    }
  }
}