blob: 1786b9e41dd2c7f7ef7e484b53c0e3485378a4ac [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.transforms;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includesDisplayDataFor;
import static org.junit.Assert.assertThat;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.coders.AtomicCoder;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.testing.DataflowPortabilityApiUnsupported;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.UsesSideInputs;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.Combine.BinaryCombineFn;
import org.apache.beam.sdk.transforms.CombineFns.CoCombineResult;
import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Objects;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.hamcrest.Matchers;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link CombineFns}. */
@RunWith(JUnit4.class)
public class CombineFnsTest {
@Rule public final TestPipeline p = TestPipeline.create();
@Rule public ExpectedException expectedException = ExpectedException.none();
/**
 * Composing two components under the same {@link TupleTag} must be rejected with an {@link
 * IllegalArgumentException} reporting that the tag is already present in the composition.
 */
@Test
public void testDuplicatedTags() {
  TupleTag<Integer> sharedTag = new TupleTag<>();

  expectedException.expect(IllegalArgumentException.class);
  expectedException.expectMessage("it is already present in the composition");

  CombineFns.compose()
      .with(new GetIntegerFunction(), Max.ofIntegers(), sharedTag)
      .with(new GetIntegerFunction(), Min.ofIntegers(), sharedTag);
}
/**
 * Reusing one {@link TupleTag} for two composed components must fail with an {@link
 * IllegalArgumentException}.
 *
 * <p>NOTE(review): despite its name, this test exercises the same unkeyed {@code
 * CombineFns.compose()} path as {@code testDuplicatedTags}. It is presumably a leftover from an
 * earlier keyed composition builder. Confirm whether a distinct keyed scenario belongs here or
 * whether the test should be dropped.
 */
@Test
public void testDuplicatedTagsKeyed() {
  TupleTag<Integer> reusedTag = new TupleTag<>();

  expectedException.expectMessage("it is already present in the composition");
  expectedException.expect(IllegalArgumentException.class);

  CombineFns.compose()
      .with(new GetIntegerFunction(), Max.ofIntegers(), reusedTag)
      .with(new GetIntegerFunction(), Min.ofIntegers(), reusedTag);
}
/**
 * Registering two context-aware {@link ConcatStringWithContext} components under one tag must be
 * rejected while the composition is being built. The failure is expected before any combining
 * happens, so no side-input view is supplied.
 */
@Test
public void testDuplicatedTagsWithContext() {
  TupleTag<UserString> sharedTag = new TupleTag<>();
  PCollectionView<String> noView = null;

  expectedException.expect(IllegalArgumentException.class);
  expectedException.expectMessage("it is already present in the composition");

  CombineFns.compose()
      .with(new GetUserStringFunction(), new ConcatStringWithContext(noView), sharedTag)
      .with(new GetUserStringFunction(), new ConcatStringWithContext(noView), sharedTag);
}
/**
 * Runs a two-part composed combine both globally and per key, then checks the extracted results.
 * The integer half is combined with Max; the UserString half is combined with ConcatString.
 *
 * <p>NOTE(review): this test uses no explicit side input. The UsesSideInputs category is
 * presumably needed because Combine.globally (with defaults) may use side inputs internally.
 * Confirm.
 */
@Test
@Category({ValidatesRunner.class, UsesSideInputs.class, DataflowPortabilityApiUnsupported.class})
public void testComposedCombine() {
p.getCoderRegistry().registerCoderForClass(UserString.class, UserStringCoder.of());
// Five (key, (int, UserString)) elements, paired positionally with the listed timestamps.
PCollection<KV<String, KV<Integer, UserString>>> perKeyInput =
p.apply(
Create.timestamped(
Arrays.asList(
KV.of("a", KV.of(1, UserString.of("1"))),
KV.of("a", KV.of(1, UserString.of("1"))),
KV.of("a", KV.of(4, UserString.of("4"))),
KV.of("b", KV.of(1, UserString.of("1"))),
KV.of("b", KV.of(13, UserString.of("13")))),
Arrays.asList(0L, 4L, 7L, 10L, 16L))
.withCoder(
KvCoder.of(
StringUtf8Coder.of(),
KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of()))));
// Tags under which each component's result is stored in the CoCombineResult.
TupleTag<Integer> maxIntTag = new TupleTag<>();
TupleTag<UserString> concatStringTag = new TupleTag<>();
// Global combine over the values only; the single result is then keyed as "global".
PCollection<KV<String, KV<Integer, String>>> combineGlobally =
perKeyInput
.apply(Values.create())
.apply(
Combine.globally(
CombineFns.compose()
.with(new GetIntegerFunction(), Max.ofIntegers(), maxIntTag)
.with(new GetUserStringFunction(), new ConcatString(), concatStringTag)))
.apply(WithKeys.of("global"))
.apply(
"ExtractGloballyResult",
ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
// Per-key combine using an identical composition.
PCollection<KV<String, KV<Integer, String>>> combinePerKey =
perKeyInput
.apply(
Combine.perKey(
CombineFns.compose()
.with(new GetIntegerFunction(), Max.ofIntegers(), maxIntTag)
.with(new GetUserStringFunction(), new ConcatString(), concatStringTag)))
.apply(
"ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
// ConcatString sorts the concatenated characters, so "1","1","4","1","13" yields "111134".
PAssert.that(combineGlobally).containsInAnyOrder(KV.of("global", KV.of(13, "111134")));
PAssert.that(combinePerKey)
.containsInAnyOrder(KV.of("a", KV.of(4, "114")), KV.of("b", KV.of(13, "113")));
p.run();
}
/**
 * Same scenario as the plain composed combine, except the string half is a context-aware combine
 * fn. That fn reads the singleton side input "I" through {@code view}. Each accumulator is seeded
 * with the side-input value, so every expected string contains one trailing "I". After the
 * character sort, "I" lands after the digits.
 */
@Test
@Category({ValidatesRunner.class, UsesSideInputs.class, DataflowPortabilityApiUnsupported.class})
public void testComposedCombineWithContext() {
p.getCoderRegistry().registerCoderForClass(UserString.class, UserStringCoder.of());
// Singleton side input consumed by ConcatStringWithContext.
PCollectionView<String> view = p.apply(Create.of("I")).apply(View.asSingleton());
// Five (key, (int, UserString)) elements, paired positionally with the listed timestamps.
PCollection<KV<String, KV<Integer, UserString>>> perKeyInput =
p.apply(
Create.timestamped(
Arrays.asList(
KV.of("a", KV.of(1, UserString.of("1"))),
KV.of("a", KV.of(1, UserString.of("1"))),
KV.of("a", KV.of(4, UserString.of("4"))),
KV.of("b", KV.of(1, UserString.of("1"))),
KV.of("b", KV.of(13, UserString.of("13")))),
Arrays.asList(0L, 4L, 7L, 10L, 16L))
.withCoder(
KvCoder.of(
StringUtf8Coder.of(),
KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of()))));
// Tags under which each component's result is stored in the CoCombineResult.
TupleTag<Integer> maxIntTag = new TupleTag<>();
TupleTag<UserString> concatStringTag = new TupleTag<>();
// Global combine without defaults; the side input must be declared explicitly.
PCollection<KV<String, KV<Integer, String>>> combineGlobally =
perKeyInput
.apply(Values.create())
.apply(
Combine.globally(
CombineFns.compose()
.with(new GetIntegerFunction(), Max.ofIntegers(), maxIntTag)
.with(
new GetUserStringFunction(),
new ConcatStringWithContext(view),
concatStringTag))
.withoutDefaults()
.withSideInputs(ImmutableList.of(view)))
.apply(WithKeys.of("global"))
.apply(
"ExtractGloballyResult",
ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
// Per-key combine; the explicit type arguments fix the output type to CoCombineResult.
PCollection<KV<String, KV<Integer, String>>> combinePerKey =
perKeyInput
.apply(
Combine.<String, KV<Integer, UserString>, CoCombineResult>perKey(
CombineFns.compose()
.with(new GetIntegerFunction(), Max.ofIntegers(), maxIntTag)
.with(
new GetUserStringFunction(),
new ConcatStringWithContext(view),
concatStringTag))
.withSideInputs(ImmutableList.of(view)))
.apply(
"ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
// Expected strings are the sorted digits followed by the side-input "I".
PAssert.that(combineGlobally).containsInAnyOrder(KV.of("global", KV.of(13, "111134I")));
PAssert.that(combinePerKey)
.containsInAnyOrder(KV.of("a", KV.of(4, "114I")), KV.of("b", KV.of(13, "113I")));
p.run();
}
/**
 * Composes Max with {@link OutputNullString}, whose combine always yields null. Checks that a null
 * component result reaches the output as a null string while the max component is unaffected.
 */
@Test
@Category({ValidatesRunner.class, DataflowPortabilityApiUnsupported.class})
public void testComposedCombineNullValues() {
// Coders are wrapped in NullableCoder, presumably so that the null values produced by
// OutputNullString (and the resulting null strings) can be encoded.
p.getCoderRegistry()
.registerCoderForClass(UserString.class, NullableCoder.of(UserStringCoder.of()));
p.getCoderRegistry()
.registerCoderForClass(String.class, NullableCoder.of(StringUtf8Coder.of()));
// Five (key, (int, UserString)) elements, paired positionally with the listed timestamps.
PCollection<KV<String, KV<Integer, UserString>>> perKeyInput =
p.apply(
Create.timestamped(
Arrays.asList(
KV.of("a", KV.of(1, UserString.of("1"))),
KV.of("a", KV.of(1, UserString.of("1"))),
KV.of("a", KV.of(4, UserString.of("4"))),
KV.of("b", KV.of(1, UserString.of("1"))),
KV.of("b", KV.of(13, UserString.of("13")))),
Arrays.asList(0L, 4L, 7L, 10L, 16L))
.withCoder(
KvCoder.of(
NullableCoder.of(StringUtf8Coder.of()),
KvCoder.of(
BigEndianIntegerCoder.of(), NullableCoder.of(UserStringCoder.of())))));
// Tags under which each component's result is stored in the CoCombineResult.
TupleTag<Integer> maxIntTag = new TupleTag<>();
TupleTag<UserString> concatStringTag = new TupleTag<>();
// Per-key combine only; the string component always produces null.
PCollection<KV<String, KV<Integer, String>>> combinePerKey =
perKeyInput
.apply(
Combine.perKey(
CombineFns.compose()
.with(new GetIntegerFunction(), Max.ofIntegers(), maxIntTag)
.with(
new GetUserStringFunction(), new OutputNullString(), concatStringTag)))
.apply(
"ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
// Casts are needed so KV.of infers KV<Integer, String> for the null string.
PAssert.that(combinePerKey)
.containsInAnyOrder(
KV.of("a", KV.of(4, (String) null)), KV.of("b", KV.of(13, (String) null)));
p.run();
}
/**
 * Display data of a composed combine fn must record each component's class under {@code
 * combineFn1} / {@code combineFn2}. It must also include each component's own display data under
 * the same keys.
 */
@Test
public void testComposedCombineDisplayData() {
  SimpleFunction<String, String> identity =
      new SimpleFunction<String, String>() {
        @Override
        public String apply(String input) {
          return input;
        }
      };
  DisplayDataCombineFn first = new DisplayDataCombineFn("value1");
  DisplayDataCombineFn second = new DisplayDataCombineFn("value2");

  CombineFns.ComposedCombineFn<String> composed =
      CombineFns.compose()
          .with(identity, first, new TupleTag<>())
          .with(identity, second, new TupleTag<>());

  DisplayData data = DisplayData.from(composed);

  assertThat(data, hasDisplayItem("combineFn1", first.getClass()));
  assertThat(data, includesDisplayDataFor("combineFn1", first));
  assertThat(data, hasDisplayItem("combineFn2", second.getClass()));
  assertThat(data, includesDisplayDataFor("combineFn2", second));
}
/**
 * Inert combine fn that exists only to contribute display data. It adds one item with a
 * per-instance key ({@code "uniqueKey" + id}) and one item under a key shared by every instance.
 * All combine methods return {@code null}.
 */
private static class DisplayDataCombineFn extends Combine.CombineFn<String, String, String> {
  // Number of instances constructed so far; each new instance takes the incremented value as id.
  private static int instanceCount;

  private final int id;
  private final String value;

  DisplayDataCombineFn(String value) {
    instanceCount += 1;
    this.id = instanceCount;
    this.value = value;
  }

  @Override
  public String createAccumulator() {
    return null;
  }

  @Override
  public String addInput(String accumulator, String input) {
    return null;
  }

  @Override
  public String mergeAccumulators(Iterable<String> accumulators) {
    return null;
  }

  @Override
  public String extractOutput(String accumulator) {
    return null;
  }

  @Override
  public void populateDisplayData(DisplayData.Builder builder) {
    String perInstanceKey = "uniqueKey" + id;
    builder
        .add(DisplayData.item(perInstanceKey, value))
        .add(DisplayData.item("sharedKey", value));
  }
}
/**
 * Minimal serializable wrapper around a {@link String}, used to exercise custom coders and combine
 * fns. Equality and hash code depend only on the wrapped string.
 */
private static class UserString implements Serializable {
  // Left mutable on purpose: ConcatStringWithContext appends to it in place.
  private String strValue;

  /** Creates a wrapper holding {@code strValue}, which may be null. */
  static UserString of(String strValue) {
    UserString wrapper = new UserString();
    wrapper.strValue = strValue;
    return wrapper;
  }

  @Override
  public boolean equals(Object o) {
    boolean sameType = o != null && getClass() == o.getClass();
    return o == this || (sameType && Objects.equal(strValue, ((UserString) o).strValue));
  }

  @Override
  public int hashCode() {
    return Objects.hashCode(strValue);
  }
}
/**
 * Coder for {@link UserString} that hands the wrapped string to {@link StringUtf8Coder}. The
 * overloads without a context argument encode and decode with {@code Context.NESTED}.
 */
private static class UserStringCoder extends AtomicCoder<UserString> {
  private static final UserStringCoder INSTANCE = new UserStringCoder();

  /** Returns the shared instance. */
  public static UserStringCoder of() {
    return INSTANCE;
  }

  @Override
  public void encode(UserString value, OutputStream outStream)
      throws CoderException, IOException {
    encode(value, outStream, Context.NESTED);
  }

  @Override
  public void encode(UserString value, OutputStream outStream, Context context)
      throws CoderException, IOException {
    String payload = value.strValue;
    StringUtf8Coder.of().encode(payload, outStream, context);
  }

  @Override
  public UserString decode(InputStream inStream) throws CoderException, IOException {
    return decode(inStream, Context.NESTED);
  }

  @Override
  public UserString decode(InputStream inStream, Context context)
      throws CoderException, IOException {
    String decoded = StringUtf8Coder.of().decode(inStream, context);
    return UserString.of(decoded);
  }

  @Override
  public List<? extends Coder<?>> getCoderArguments() {
    return Collections.emptyList();
  }

  @Override
  public void verifyDeterministic() throws NonDeterministicException {}
}
/** Extracts the integer key of a {@code KV<Integer, UserString>}. */
private static class GetIntegerFunction extends SimpleFunction<KV<Integer, UserString>, Integer> {
  @Override
  public Integer apply(KV<Integer, UserString> pair) {
    Integer number = pair.getKey();
    return number;
  }
}
/** Extracts the {@link UserString} value of a {@code KV<Integer, UserString>}. */
private static class GetUserStringFunction
    extends SimpleFunction<KV<Integer, UserString>, UserString> {
  @Override
  public UserString apply(KV<Integer, UserString> pair) {
    UserString payload = pair.getValue();
    return payload;
  }
}
/**
 * Binary combiner that concatenates two {@link UserString}s and sorts the characters of the
 * result. The output therefore does not depend on the order in which values are combined.
 */
private static class ConcatString extends BinaryCombineFn<UserString> {
  @Override
  public UserString apply(UserString left, UserString right) {
    char[] letters = (left.strValue + right.strValue).toCharArray();
    Arrays.sort(letters);
    return UserString.of(String.valueOf(letters));
  }
}
/** Binary combiner that ignores both operands and always returns {@code null}. */
private static class OutputNullString extends BinaryCombineFn<UserString> {
  @Override
  public UserString apply(UserString ignoredLeft, UserString ignoredRight) {
    return null;
  }
}
/**
 * Context-aware concatenating combiner.
 *
 * <ul>
 *   <li>Every accumulator starts as the side-input string read from {@code view}.
 *   <li>Inputs are appended to it.
 *   <li>A merge keeps a single copy of that prefix.
 *   <li>The extracted output has its characters sorted.
 * </ul>
 *
 * Assertions check that every accumulator it receives still begins with the prefix.
 */
private static class ConcatStringWithContext
    extends CombineFnWithContext<UserString, UserString, UserString> {
  private final PCollectionView<String> view;

  private ConcatStringWithContext(PCollectionView<String> view) {
    this.view = view;
  }

  @Override
  public UserString createAccumulator(CombineWithContext.Context c) {
    String prefix = c.sideInput(view);
    return UserString.of(prefix);
  }

  @Override
  public UserString addInput(
      UserString accumulator, UserString input, CombineWithContext.Context c) {
    assertThat(accumulator.strValue, Matchers.startsWith(c.sideInput(view)));
    accumulator.strValue = accumulator.strValue + input.strValue;
    return accumulator;
  }

  @Override
  public UserString mergeAccumulators(
      Iterable<UserString> accumulators, CombineWithContext.Context c) {
    String prefix = c.sideInput(view);
    String merged = prefix;
    for (UserString acc : accumulators) {
      assertThat(acc.strValue, Matchers.startsWith(prefix));
      // Keep only the part after the prefix so the merged result carries the prefix once.
      merged = merged + acc.strValue.substring(prefix.length());
      // Overwrite the consumed accumulator; presumably so any later reuse fails the prefix checks.
      acc.strValue = "cleared in mergeAccumulators";
    }
    return UserString.of(merged);
  }

  @Override
  public UserString extractOutput(UserString accumulator, CombineWithContext.Context c) {
    assertThat(accumulator.strValue, Matchers.startsWith(c.sideInput(view)));
    char[] sorted = accumulator.strValue.toCharArray();
    Arrays.sort(sorted);
    return UserString.of(String.valueOf(sorted));
  }
}
/**
 * Converts a keyed {@link CoCombineResult} into {@code KV(key, KV(max, string))}. It reads the
 * integer result under {@code maxIntTag} and the {@link UserString} result under {@code
 * concatStringTag}. A null UserString result becomes a null string.
 */
private static class ExtractResultDoFn
    extends DoFn<KV<String, CoCombineResult>, KV<String, KV<Integer, String>>> {
  private final TupleTag<Integer> maxIntTag;
  private final TupleTag<UserString> concatStringTag;

  ExtractResultDoFn(TupleTag<Integer> maxIntTag, TupleTag<UserString> concatStringTag) {
    this.maxIntTag = maxIntTag;
    this.concatStringTag = concatStringTag;
  }

  @ProcessElement
  public void processElement(ProcessContext c) throws Exception {
    KV<String, CoCombineResult> element = c.element();
    CoCombineResult combined = element.getValue();
    UserString concatenated = combined.get(concatStringTag);
    String text = (concatenated != null) ? concatenated.strValue : null;
    Integer max = combined.get(maxIntTag);
    c.output(KV.of(element.getKey(), KV.of(max, text)));
  }
}
}