// blob: 360886231ad95c52da88aac29eff6e4a7400ad3a
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker;
import static org.apache.beam.runners.dataflow.util.Structs.addString;
import static org.apache.beam.sdk.util.SerializableUtils.serializeToByteArray;
import static org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import com.google.api.services.dataflow.model.SideInputInfo;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.apache.beam.runners.dataflow.util.CloudObject;
import org.apache.beam.runners.dataflow.util.PropertyNames;
import org.apache.beam.runners.dataflow.worker.util.WorkerPropertyNames;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ParDoFn;
import org.apache.beam.runners.dataflow.worker.util.common.worker.Receiver;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.CustomCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.util.AppliedCombineFn;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v20_0.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link CombineValuesFnFactory}. */
@RunWith(JUnit4.class)
public class CombineValuesFnFactoryTest {
/**
 * Sample {@link Combine.AccumulatingCombineFn} for these tests, declared with {@code Integer}
 * inputs, a {@link CountSum} accumulator and {@code String} output.
 */
public static class MeanInts extends Combine.AccumulatingCombineFn<Integer, CountSum, String> {
  /** Returns a new accumulator whose count and sum both start at zero. */
  @Override
  public CountSum createAccumulator() {
    final long initialCount = 0;
    final double initialSum = 0.0;
    return new CountSum(initialCount, initialSum);
  }

  /** Always supplies a new {@link CountSumCoder}; neither argument is consulted. */
  @Override
  public Coder<CountSum> getAccumulatorCoder(CoderRegistry registry, Coder<Integer> inputCoder) {
    Coder<CountSum> accumulatorCoder = new CountSumCoder();
    return accumulatorCoder;
  }
}
/**
 * Mutable accumulator holding the number of integers seen so far and their running sum.
 *
 * <p>Equality is deliberately approximate: two instances are equal when their counts match and
 * their sums differ by less than 0.1. Because of that tolerance, {@link #hashCode()} is derived
 * from the count alone, so any two instances that compare equal also produce the same hash.
 */
static class CountSum
    implements Combine.AccumulatingCombineFn.Accumulator<Integer, CountSum, String> {
  long count;
  double sum;

  /**
   * Creates an accumulator with the given starting state.
   *
   * @param count number of inputs already represented
   * @param sum sum of the inputs already represented
   */
  public CountSum(long count, double sum) {
    this.count = count;
    this.sum = sum;
  }

  /** Counts one more input and adds its value to the running sum. */
  @Override
  public void addInput(Integer element) {
    count++;
    sum += element.doubleValue();
  }

  /** Folds another accumulator's count and sum into this one. */
  @Override
  public void mergeAccumulator(CountSum accumulator) {
    count += accumulator.count;
    sum += accumulator.sum;
  }

  /**
   * Returns the mean formatted with one decimal place.
   *
   * @return {@code sum / count} formatted via {@code "%.1f"}, or the formatting of {@code 0.0}
   *     when no inputs have been counted (avoids dividing by zero)
   */
  @Override
  public String extractOutput() {
    return String.format("%.1f", count == 0 ? 0.0 : sum / count);
  }

  /**
   * Hashes the count only.
   *
   * <p>The sum is intentionally excluded: {@link #equals(Object)} accepts sums that differ by
   * less than 0.1, so hashing the exact sum would let equal instances produce different hashes.
   */
  @Override
  public int hashCode() {
    return Objects.hash(count);
  }

  /**
   * Compares with tolerance on the sum.
   *
   * @return {@code true} when {@code obj} is this instance, or is a {@code CountSum} with the same
   *     count and a sum within 0.1 of this one
   */
  @Override
  public boolean equals(Object obj) {
    if (obj == this) {
      return true;
    }
    if (!(obj instanceof CountSum)) {
      return false;
    }
    CountSum other = (CountSum) obj;
    return (this.count == other.count) && (Math.abs(this.sum - other.sum) < 0.1);
  }

  @Override
  public String toString() {
    return MoreObjects.toStringHelper(this).add("count", count).add("sum", sum).toString();
  }
}
/** Sample accumulator coder that reports its byte-size observation as "cheap". */
public static class CountSumCoder extends CustomCoder<CountSum> {
  public CountSumCoder() {}

  /** Writes the count as a {@code long} followed by the sum as a {@code double}. */
  @Override
  public void encode(CountSum value, OutputStream outStream) throws CoderException, IOException {
    DataOutputStream out = new DataOutputStream(outStream);
    out.writeLong(value.count);
    out.writeDouble(value.sum);
  }

  /** Reads a {@code long} count and then a {@code double} sum, mirroring {@link #encode}. */
  @Override
  public CountSum decode(InputStream inStream) throws CoderException, IOException {
    DataInputStream in = new DataInputStream(inStream);
    final long decodedCount = in.readLong();
    final double decodedSum = in.readDouble();
    return new CountSum(decodedCount, decodedSum);
  }

  /** Intentionally empty: this coder never reports itself as non-deterministic. */
  @Override
  public void verifyDeterministic() {}

  /** Always {@code true}, regardless of the value. */
  @Override
  public boolean isRegisterByteSizeObserverCheap(CountSum value) {
    return true;
  }

  /** Reports a fixed size of 16: one {@code long} plus one {@code double}, as written by encode. */
  @Override
  public void registerByteSizeObserver(CountSum value, ElementByteSizeObserver observer)
      throws Exception {
    observer.update(16L);
  }
}
/** {@link Receiver} that appends every element it is given to {@code receivedElems}. */
static class TestReceiver implements Receiver {
  List<Object> receivedElems = new ArrayList<Object>();

  /** Records {@code outputElem} at the end of the list. */
  @Override
  public void process(Object outputElem) {
    this.receivedElems.add(outputElem);
  }
}
// Factory under test; createCombineValuesFn uses it to build every ParDoFn exercised below.
private static final ParDoFnFactory parDoFnFactory = new CombineValuesFnFactory();
// Tag for the single main output; createCombineValuesFn maps it to receiver index 0.
private static final TupleTag<?> MAIN_OUTPUT = new TupleTag<>("output");
/**
 * Builds a {@link ParDoFn} through {@code parDoFnFactory} for the given combine phase.
 *
 * <p>The combine function is wrapped in an {@link AppliedCombineFn}, serialized into a
 * {@code CombineValuesFn} cloud object together with the phase, and handed to the factory with no
 * side inputs and {@code MAIN_OUTPUT} as the only output.
 *
 * @param phase phase name stored under {@link WorkerPropertyNames#PHASE}
 * @param combineFn combine function to wrap
 * @param keyCoder coder for the key half of the KV coder
 * @param inputCoder coder for the value half of the KV coder
 * @param accumCoder accumulator coder passed to the applied combine function
 * @param windowingStrategy windowing strategy passed to the applied combine function
 * @return the {@link ParDoFn} created by the factory
 */
private <K, InputT, AccumT, OutputT> ParDoFn createCombineValuesFn(
    String phase,
    Combine.CombineFn<InputT, AccumT, OutputT> combineFn,
    Coder<K> keyCoder,
    Coder<InputT> inputCoder,
    Coder<AccumT> accumCoder,
    WindowingStrategy<?, ?> windowingStrategy)
    throws Exception {
  // Partially mirrors what
  // org.apache.beam.runners.dataflow.worker.transforms.Combine.translateHelper
  // does, at least for the KeyedCombineFn. In production the phase comes from the back-end;
  // here the caller supplies it.
  KvCoder<K, InputT> kvCoder = KvCoder.of(keyCoder, inputCoder);
  @SuppressWarnings("unchecked")
  AppliedCombineFn applied =
      AppliedCombineFn.withAccumulatorCoder(
          combineFn, accumCoder, Collections.emptyList(), kvCoder, windowingStrategy);
  String serializedFn = byteArrayToJsonString(serializeToByteArray(applied));

  CloudObject spec = CloudObject.forClassName("CombineValuesFn");
  addString(spec, PropertyNames.SERIALIZED_FN, serializedFn);
  addString(spec, WorkerPropertyNames.PHASE, phase);

  return parDoFnFactory.create(
      PipelineOptionsFactory.create(),
      spec,
      ImmutableList.<SideInputInfo>of(),
      MAIN_OUTPUT,
      ImmutableMap.<TupleTag<?>, Integer>of(MAIN_OUTPUT, 0),
      BatchModeExecutionContext.forTesting(PipelineOptionsFactory.create(), "testStage"),
      TestOperationContext.create());
}
/**
 * Exercises the {@code ALL} phase: three keyed lists of integers go in, and one formatted mean
 * per key is expected out, in input order.
 */
@Test
public void testCombineValuesFnAll() throws Exception {
  Combine.CombineFn<Integer, CountSum, String> meanFn = new MeanInts();
  ParDoFn fnUnderTest =
      createCombineValuesFn(
          CombinePhase.ALL,
          meanFn,
          StringUtf8Coder.of(),
          BigEndianIntegerCoder.of(),
          new CountSumCoder(),
          WindowingStrategy.globalDefault());

  TestReceiver sink = new TestReceiver();
  fnUnderTest.startBundle(sink);
  for (WindowedValue<?> input :
      ImmutableList.of(
          WindowedValue.valueInGlobalWindow(KV.of("a", Arrays.asList(5, 6, 7))),
          WindowedValue.valueInGlobalWindow(KV.of("b", Arrays.asList(1, 3, 7))),
          WindowedValue.valueInGlobalWindow(KV.of("c", Arrays.asList(3, 6, 8, 9))))) {
    fnUnderTest.processElement(input);
  }
  fnUnderTest.finishBundle();

  Object[] actual = sink.receivedElems.toArray();
  Object[] expected = {
    WindowedValue.valueInGlobalWindow(KV.of("a", String.format("%.1f", 6.0))),
    WindowedValue.valueInGlobalWindow(KV.of("b", String.format("%.1f", 3.7))),
    WindowedValue.valueInGlobalWindow(KV.of("c", String.format("%.1f", 6.5))),
  };
  assertArrayEquals(expected, actual);
}
/**
 * Exercises the {@code ADD} phase: three keyed lists of integers go in, and one {@link CountSum}
 * accumulator per key is expected out, in input order.
 */
@Test
public void testCombineValuesFnAdd() throws Exception {
  Combine.CombineFn<Integer, CountSum, String> meanFn = new MeanInts();
  ParDoFn fnUnderTest =
      createCombineValuesFn(
          CombinePhase.ADD,
          meanFn,
          StringUtf8Coder.of(),
          BigEndianIntegerCoder.of(),
          new CountSumCoder(),
          WindowingStrategy.globalDefault());

  TestReceiver sink = new TestReceiver();
  fnUnderTest.startBundle(sink);
  for (WindowedValue<?> input :
      ImmutableList.of(
          WindowedValue.valueInGlobalWindow(KV.of("a", Arrays.asList(5, 6, 7))),
          WindowedValue.valueInGlobalWindow(KV.of("b", Arrays.asList(1, 3, 7))),
          WindowedValue.valueInGlobalWindow(KV.of("c", Arrays.asList(3, 6, 8, 9))))) {
    fnUnderTest.processElement(input);
  }
  fnUnderTest.finishBundle();

  Object[] actual = sink.receivedElems.toArray();
  Object[] expected = {
    WindowedValue.valueInGlobalWindow(KV.of("a", new CountSum(3, 18))),
    WindowedValue.valueInGlobalWindow(KV.of("b", new CountSum(3, 11))),
    WindowedValue.valueInGlobalWindow(KV.of("c", new CountSum(4, 26)))
  };
  assertArrayEquals(expected, actual);
}
/**
 * Exercises the {@code MERGE} phase: each key carries a list of {@link CountSum} accumulators,
 * and a single accumulator with the summed counts and sums is expected per key.
 */
@Test
public void testCombineValuesFnMerge() throws Exception {
  Combine.CombineFn<Integer, CountSum, String> meanFn = new MeanInts();
  ParDoFn fnUnderTest =
      createCombineValuesFn(
          CombinePhase.MERGE,
          meanFn,
          StringUtf8Coder.of(),
          BigEndianIntegerCoder.of(),
          new CountSumCoder(),
          WindowingStrategy.globalDefault());

  TestReceiver sink = new TestReceiver();
  fnUnderTest.startBundle(sink);
  for (WindowedValue<?> input :
      ImmutableList.of(
          WindowedValue.valueInGlobalWindow(
              KV.of(
                  "a",
                  Arrays.asList(new CountSum(3, 6), new CountSum(2, 9), new CountSum(1, 12)))),
          WindowedValue.valueInGlobalWindow(
              KV.of("b", Arrays.asList(new CountSum(2, 20), new CountSum(1, 1)))))) {
    fnUnderTest.processElement(input);
  }
  fnUnderTest.finishBundle();

  Object[] actual = sink.receivedElems.toArray();
  Object[] expected = {
    WindowedValue.valueInGlobalWindow(KV.of("a", new CountSum(6, 27))),
    WindowedValue.valueInGlobalWindow(KV.of("b", new CountSum(3, 21))),
  };
  assertArrayEquals(expected, actual);
}
/**
 * Exercises the {@code EXTRACT} phase: each key carries one {@link CountSum} accumulator, and the
 * formatted mean for that accumulator is expected per key.
 */
@Test
public void testCombineValuesFnExtract() throws Exception {
  Combine.CombineFn<Integer, CountSum, String> meanFn = new MeanInts();
  ParDoFn fnUnderTest =
      createCombineValuesFn(
          CombinePhase.EXTRACT,
          meanFn,
          StringUtf8Coder.of(),
          BigEndianIntegerCoder.of(),
          new CountSumCoder(),
          WindowingStrategy.globalDefault());

  TestReceiver sink = new TestReceiver();
  fnUnderTest.startBundle(sink);
  for (WindowedValue<?> input :
      ImmutableList.of(
          WindowedValue.valueInGlobalWindow(KV.of("a", new CountSum(6, 27))),
          WindowedValue.valueInGlobalWindow(KV.of("b", new CountSum(3, 21))))) {
    fnUnderTest.processElement(input);
  }
  fnUnderTest.finishBundle();

  Object[] expected = {
    WindowedValue.valueInGlobalWindow(KV.of("a", String.format("%.1f", 4.5))),
    WindowedValue.valueInGlobalWindow(KV.of("b", String.format("%.1f", 7.0)))
  };
  Object[] actual = sink.receivedElems.toArray();
  assertArrayEquals(expected, actual);
}
/**
 * Round-trips a {@link CountSum} through the accumulator coder supplied by {@link MeanInts} and
 * expects the decoded value to equal the original.
 */
@Test
public void testCombineValuesFnCoders() throws Exception {
  CoderRegistry coderRegistry = CoderRegistry.createDefault();
  Coder<Integer> integerCoder = coderRegistry.getCoder(TypeDescriptor.of(Integer.class));
  Coder<CountSum> accumulatorCoder =
      new MeanInts().getAccumulatorCoder(coderRegistry, integerCoder);

  CountSum original = new CountSum(6, 27);
  byte[] encoded = CoderUtils.encodeToByteArray(accumulatorCoder, original);
  CountSum roundTripped = CoderUtils.decodeFromByteArray(accumulatorCoder, encoded);

  assertEquals(original, roundTripped);
}
}