blob: 966710de2a942527260e9a9bc222e337c765cadc [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker;
import static org.apache.beam.runners.dataflow.worker.util.common.worker.TestOutputReceiver.TestOutputCounter.getMeanByteCounterName;
import static org.apache.beam.runners.dataflow.worker.util.common.worker.TestOutputReceiver.TestOutputCounter.getObjectCounterName;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Set;
import org.apache.beam.runners.core.NullSideInputReader;
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateTag;
import org.apache.beam.runners.core.StepContext;
import org.apache.beam.runners.dataflow.worker.IntrinsicMapTaskExecutorFactory.ElementByteSizeObservableCoder;
import org.apache.beam.runners.dataflow.worker.PartialGroupByKeyParDoFns.BatchSideInputPGBKParDoFn;
import org.apache.beam.runners.dataflow.worker.PartialGroupByKeyParDoFns.CoderSizeEstimator;
import org.apache.beam.runners.dataflow.worker.PartialGroupByKeyParDoFns.PairInfo;
import org.apache.beam.runners.dataflow.worker.PartialGroupByKeyParDoFns.StreamingSideInputPGBKParDoFn;
import org.apache.beam.runners.dataflow.worker.PartialGroupByKeyParDoFns.WindowingCoderGroupingKeyCreator;
import org.apache.beam.runners.dataflow.worker.counters.Counter.CounterUpdateExtractor;
import org.apache.beam.runners.dataflow.worker.counters.CounterFactory.LongCounterMean;
import org.apache.beam.runners.dataflow.worker.counters.CounterSet;
import org.apache.beam.runners.dataflow.worker.util.common.worker.GroupingTables;
import org.apache.beam.runners.dataflow.worker.util.common.worker.GroupingTables.Combiner;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ParDoFn;
import org.apache.beam.runners.dataflow.worker.util.common.worker.SimplePartialGroupByKeyParDoFn;
import org.apache.beam.runners.dataflow.worker.util.common.worker.TestOutputReceiver;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.StreamingOptions;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.AppliedCombineFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.hamcrest.collection.IsIterableContainingInAnyOrder;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
/** Tests for {@link PartialGroupByKeyParDoFns}. */
@RunWith(JUnit4.class)
@SuppressWarnings({"rawtypes", "unchecked"})
public class PartialGroupByKeyParDoFnsTest {
@Mock private StreamingSideInputFetcher<KV<String, Integer>, BoundedWindow> mockSideInputFetcher;
@Mock private BagState<WindowedValue<KV<String, Integer>>> elemsBag;
@Mock private SideInputReader mockSideInputReader;
@Mock private StreamingModeExecutionContext.StepContext mockStreamingStepContext;
@Mock private StateInternals mockStateInternals;
@Mock private ValueState mockState;
@Mock private Coder<String> mockCoder;
private final CounterSet counterSet = new CounterSet();
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
}
@Test
public void testPartialGroupByKey() throws Exception {
Coder keyCoder = StringUtf8Coder.of();
Coder valueCoder = BigEndianIntegerCoder.of();
TestOutputReceiver receiver =
new TestOutputReceiver(
new ElementByteSizeObservableCoder(
WindowedValue.getValueOnlyCoder(
KvCoder.of(keyCoder, IterableCoder.of(valueCoder)))),
counterSet,
NameContextsForTests.nameContextForTest());
ParDoFn pgbkParDoFn =
new SimplePartialGroupByKeyParDoFn(
GroupingTables.buffering(
new WindowingCoderGroupingKeyCreator(keyCoder),
PairInfo.create(),
new CoderSizeEstimator(WindowedValue.getValueOnlyCoder(keyCoder)),
new CoderSizeEstimator(valueCoder)),
receiver);
pgbkParDoFn.startBundle(receiver);
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("hi", 4)));
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("there", 5)));
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("hi", 6)));
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("joe", 7)));
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("there", 8)));
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("hi", 9)));
pgbkParDoFn.finishBundle();
assertThat(
receiver.outputElems,
IsIterableContainingInAnyOrder.<Object>containsInAnyOrder(
WindowedValue.valueInGlobalWindow(KV.of("hi", Arrays.asList(4, 6, 9))),
WindowedValue.valueInGlobalWindow(KV.of("there", Arrays.asList(5, 8))),
WindowedValue.valueInGlobalWindow(KV.of("joe", Arrays.asList(7)))));
// Exact counter values depend on size of encoded data. If encoding
// changes, then these expected counters should change to match.
CounterUpdateExtractor<?> updateExtractor = Mockito.mock(CounterUpdateExtractor.class);
counterSet.extractUpdates(false, updateExtractor);
verify(updateExtractor).longSum(getObjectCounterName("test_receiver_out"), false, 3L);
verify(updateExtractor)
.longMean(
getMeanByteCounterName("test_receiver_out"),
false,
LongCounterMean.ZERO.addValue(49L, 3));
verifyNoMoreInteractions(updateExtractor);
}
@Test
public void testPartialGroupByKeyWithCombiner() throws Exception {
Coder keyCoder = StringUtf8Coder.of();
Coder valueCoder = BigEndianIntegerCoder.of();
TestOutputReceiver receiver =
new TestOutputReceiver(
new ElementByteSizeObservableCoder(
WindowedValue.getValueOnlyCoder(KvCoder.of(keyCoder, valueCoder))),
counterSet,
NameContextsForTests.nameContextForTest());
Combiner<WindowedValue<String>, Integer, Integer, Integer> combineFn = new TestCombiner();
ParDoFn pgbkParDoFn =
new SimplePartialGroupByKeyParDoFn(
GroupingTables.combining(
new WindowingCoderGroupingKeyCreator(keyCoder),
PairInfo.create(),
combineFn,
new CoderSizeEstimator(WindowedValue.getValueOnlyCoder(keyCoder)),
new CoderSizeEstimator(valueCoder)),
receiver);
pgbkParDoFn.startBundle(receiver);
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("hi", 4)));
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("there", 5)));
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("hi", 6)));
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("joe", 7)));
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("there", 8)));
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("hi", 9)));
pgbkParDoFn.finishBundle();
assertThat(
receiver.outputElems,
IsIterableContainingInAnyOrder.<Object>containsInAnyOrder(
WindowedValue.valueInGlobalWindow(KV.of("hi", 19)),
WindowedValue.valueInGlobalWindow(KV.of("there", 13)),
WindowedValue.valueInGlobalWindow(KV.of("joe", 7))));
// Exact counter values depend on size of encoded data. If encoding
// changes, then these expected counters should change to match.
CounterUpdateExtractor<?> updateExtractor = Mockito.mock(CounterUpdateExtractor.class);
counterSet.extractUpdates(false, updateExtractor);
verify(updateExtractor).longSum(getObjectCounterName("test_receiver_out"), false, 3L);
verify(updateExtractor)
.longMean(
getMeanByteCounterName("test_receiver_out"),
false,
LongCounterMean.ZERO.addValue(25L, 3));
verifyNoMoreInteractions(updateExtractor);
}
@Test
public void testPartialGroupByKeyWithCombinerAndSideInputs() throws Exception {
Coder keyCoder = StringUtf8Coder.of();
Coder valueCoder = BigEndianIntegerCoder.of();
TestOutputReceiver receiver =
new TestOutputReceiver(
new ElementByteSizeObservableCoder(
WindowedValue.getValueOnlyCoder(KvCoder.of(keyCoder, valueCoder))),
counterSet,
NameContextsForTests.nameContextForTest());
Combiner<WindowedValue<String>, Integer, Integer, Integer> combineFn = new TestCombiner();
ParDoFn pgbkParDoFn =
new StreamingSideInputPGBKParDoFn(
GroupingTables.combining(
new WindowingCoderGroupingKeyCreator(keyCoder),
PairInfo.create(),
combineFn,
new CoderSizeEstimator(WindowedValue.getValueOnlyCoder(keyCoder)),
new CoderSizeEstimator(valueCoder)),
receiver,
mockSideInputFetcher);
Set<BoundedWindow> readyWindows = ImmutableSet.<BoundedWindow>of(GlobalWindow.INSTANCE);
when(mockSideInputFetcher.getReadyWindows()).thenReturn(readyWindows);
when(mockSideInputFetcher.prefetchElements(readyWindows))
.thenReturn(ImmutableList.of(elemsBag));
when(elemsBag.read())
.thenReturn(
ImmutableList.of(
WindowedValue.valueInGlobalWindow(KV.of("hi", 4)),
WindowedValue.valueInGlobalWindow(KV.of("there", 5))));
when(mockSideInputFetcher.storeIfBlocked(Matchers.<WindowedValue<KV<String, Integer>>>any()))
.thenReturn(false, false, false, true);
pgbkParDoFn.startBundle(receiver);
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("hi", 6)));
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("joe", 7)));
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("there", 8)));
pgbkParDoFn.processElement(WindowedValue.valueInGlobalWindow(KV.of("hi", 9)));
pgbkParDoFn.finishBundle();
assertThat(
receiver.outputElems,
IsIterableContainingInAnyOrder.<Object>containsInAnyOrder(
WindowedValue.valueInGlobalWindow(KV.of("hi", 10)),
WindowedValue.valueInGlobalWindow(KV.of("there", 13)),
WindowedValue.valueInGlobalWindow(KV.of("joe", 7))));
// Exact counter values depend on size of encoded data. If encoding
// changes, then these expected counters should change to match.
CounterUpdateExtractor<?> updateExtractor = Mockito.mock(CounterUpdateExtractor.class);
counterSet.extractUpdates(false, updateExtractor);
verify(updateExtractor).longSum(getObjectCounterName("test_receiver_out"), false, 3L);
verify(updateExtractor)
.longMean(
getMeanByteCounterName("test_receiver_out"),
false,
LongCounterMean.ZERO.addValue(25L, 3));
verifyNoMoreInteractions(updateExtractor);
}
@Test
public void testCreateWithCombinerAndBatchSideInputs() throws Exception {
PipelineOptions options = PipelineOptionsFactory.create();
Coder keyCoder = StringUtf8Coder.of();
Coder valueCoder = BigEndianIntegerCoder.of();
KvCoder<String, Integer> kvCoder = KvCoder.of(keyCoder, valueCoder);
TestOutputReceiver receiver =
new TestOutputReceiver(
new ElementByteSizeObservableCoder(WindowedValue.getValueOnlyCoder(kvCoder)),
counterSet,
NameContextsForTests.nameContextForTest());
StepContext stepContext =
BatchModeExecutionContext.forTesting(options, "testStage")
.getStepContext(TestOperationContext.create(counterSet));
when(mockSideInputReader.isEmpty()).thenReturn(false);
ParDoFn pgbk =
PartialGroupByKeyParDoFns.create(
options,
kvCoder,
AppliedCombineFn.withInputCoder(
Sum.ofIntegers(),
CoderRegistry.createDefault(),
kvCoder,
ImmutableList.<PCollectionView<?>>of(),
WindowingStrategy.globalDefault()),
mockSideInputReader,
receiver,
stepContext);
assertTrue(pgbk instanceof BatchSideInputPGBKParDoFn);
}
@Test
public void testCreateWithCombinerAndStreaming() throws Exception {
StreamingOptions options = PipelineOptionsFactory.as(StreamingOptions.class);
options.setStreaming(true);
Coder keyCoder = StringUtf8Coder.of();
Coder valueCoder = BigEndianIntegerCoder.of();
KvCoder<String, Integer> kvCoder = KvCoder.of(keyCoder, valueCoder);
TestOutputReceiver receiver =
new TestOutputReceiver(
new ElementByteSizeObservableCoder(WindowedValue.getValueOnlyCoder(kvCoder)),
counterSet,
NameContextsForTests.nameContextForTest());
ParDoFn pgbk =
PartialGroupByKeyParDoFns.create(
options,
kvCoder,
AppliedCombineFn.withInputCoder(
Sum.ofIntegers(), CoderRegistry.createDefault(), kvCoder),
NullSideInputReader.empty(),
receiver,
null);
assertTrue(pgbk instanceof SimplePartialGroupByKeyParDoFn);
}
@Test
public void testCreateWithCombinerAndStreamingSideInputs() throws Exception {
StreamingOptions options = PipelineOptionsFactory.as(StreamingOptions.class);
options.setStreaming(true);
Coder keyCoder = StringUtf8Coder.of();
Coder valueCoder = BigEndianIntegerCoder.of();
KvCoder<String, Integer> kvCoder = KvCoder.of(keyCoder, valueCoder);
TestOutputReceiver receiver =
new TestOutputReceiver(
new ElementByteSizeObservableCoder(WindowedValue.getValueOnlyCoder(kvCoder)),
counterSet,
NameContextsForTests.nameContextForTest());
when(mockSideInputReader.isEmpty()).thenReturn(false);
when(mockStreamingStepContext.stateInternals()).thenReturn((StateInternals) mockStateInternals);
when(mockStateInternals.state(Matchers.<StateNamespace>any(), Matchers.<StateTag>any()))
.thenReturn(mockState);
when(mockState.read()).thenReturn(Maps.newHashMap());
ParDoFn pgbk =
PartialGroupByKeyParDoFns.create(
options,
kvCoder,
AppliedCombineFn.withInputCoder(
Sum.ofIntegers(),
CoderRegistry.createDefault(),
kvCoder,
ImmutableList.<PCollectionView<?>>of(),
WindowingStrategy.globalDefault()),
mockSideInputReader,
receiver,
mockStreamingStepContext);
assertTrue(pgbk instanceof StreamingSideInputPGBKParDoFn);
}
/**
* Test that CoderSizeEstimator correctly uses ElementByteSizeObserver if it reports itself as
* non-lazy
*/
@Test
public void testCoderSizeEstimationWithNonLazyObserver() throws Exception {
doAnswer(
invocation -> {
Object[] args = invocation.getArguments();
((ElementByteSizeObserver) args[1]).update(5);
// Observer assumed non-lazy by default if setLazy not called
return null;
})
.when(mockCoder)
.registerByteSizeObserver(Matchers.eq("apple"), Matchers.<ElementByteSizeObserver>any());
CoderSizeEstimator<String> estimator = new CoderSizeEstimator(mockCoder);
assertEquals(5, estimator.estimateSize("apple"));
}
@Test
public void testCoderSizeEstimationWithLazyObserver() throws Exception {
// Have the code report that byte size observation is lazy
doAnswer(
invocation -> {
Object[] args = invocation.getArguments();
// Set lazy to true and do not update the size at all
((ElementByteSizeObserver) args[1]).setLazy();
return null;
})
.when(mockCoder)
.registerByteSizeObserver(Matchers.eq("apple"), Matchers.<ElementByteSizeObserver>any());
// Encode the input to the output stream
doAnswer(
invocation -> {
Object[] args = invocation.getArguments();
String value = (String) args[0];
OutputStream os = (OutputStream) args[1];
os.write(value.getBytes(StandardCharsets.UTF_8));
return null;
})
.when(mockCoder)
.encode(Matchers.eq("apple"), Matchers.<OutputStream>any());
CoderSizeEstimator<String> estimator = new CoderSizeEstimator(mockCoder);
// Observer never updates size, so if result is 5, must have delegated to actual encoding
assertEquals(5L, estimator.estimateSize("apple"));
}
private static class TestCombiner
implements Combiner<WindowedValue<String>, Integer, Integer, Integer> {
@Override
public Integer createAccumulator(WindowedValue<String> key) {
return 0;
}
@Override
public Integer add(WindowedValue<String> key, Integer accumulator, Integer value) {
return accumulator + value;
}
@Override
public Integer merge(WindowedValue<String> key, Iterable<Integer> accumulators) {
Integer sum = 0;
for (Integer part : accumulators) {
sum += part;
}
return sum;
}
@Override
public Integer compact(WindowedValue<String> key, Integer accumulator) {
return accumulator;
}
@Override
public Integer extract(WindowedValue<String> key, Integer accumulator) {
return accumulator;
}
}
}