blob: 4591ca0f44a246f6a0fdaa4920c37232f8e32333 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.util.common.worker;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.when;
import java.io.Closeable;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import javax.annotation.Nullable;
import org.apache.beam.runners.core.metrics.ExecutionStateSampler;
import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions;
import org.apache.beam.runners.dataflow.worker.BatchModeExecutionContext;
import org.apache.beam.runners.dataflow.worker.DataflowOperationContext.DataflowExecutionState;
import org.apache.beam.runners.dataflow.worker.ExperimentContext.Experiment;
import org.apache.beam.runners.dataflow.worker.TestOperationContext.TestDataflowExecutionState;
import org.apache.beam.runners.dataflow.worker.counters.Counter;
import org.apache.beam.runners.dataflow.worker.counters.NameContext;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.util.common.Reiterator;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
/** Tests for {@link GroupingShuffleEntryIterator}. */
@RunWith(JUnit4.class)
public class GroupingShuffleEntryIteratorTest {
// Position range ("aaa" to "zzz") passed to the reader and to the iterator under test.
private static final ByteArrayShufflePosition START_POSITION =
ByteArrayShufflePosition.of("aaa".getBytes(StandardCharsets.UTF_8));
private static final ByteArrayShufflePosition END_POSITION =
ByteArrayShufflePosition.of("zzz".getBytes(StandardCharsets.UTF_8));
// Name components fed to NameContext.create() when a test enters an execution state.
private static final String MOCK_STAGE_NAME = "mockStageName";
private static final String MOCK_ORIGINAL_NAME_FOR_EXECUTING_STEP1 = "mockOriginalName1";
private static final String MOCK_SYSTEM_NAME = "mockSystemName";
private static final String MOCK_USER_NAME = "mockUserName";
// Step name given to ShuffleReadCounter and to ShuffleReadCounter.generateCounterName().
private static final String ORIGINAL_SHUFFLE_STEP_NAME = "originalName";
// Mock reader; it is only populated by tests that call MockitoAnnotations.initMocks(this).
@Mock private ShuffleEntryReader reader;
// Iterator under test; each test method assigns its own instance.
private GroupingShuffleEntryIterator iterator;
// Tracker built on a test sampler; activated in setUp().
private final ExecutionStateSampler sampler = ExecutionStateSampler.newForTest();
private final ExecutionStateTracker tracker = new ExecutionStateTracker(sampler);
// Handle returned by tracker.activate(); closed in tearDown().
private Closeable trackerCleanup;
/** Activates the execution state tracker and stores the returned handle for later closing. */
@Before
public void setUp() {
  this.trackerCleanup = this.tracker.activate();
}
/**
 * Closes the handle that was obtained from {@code tracker.activate()}.
 *
 * @throws IOException if closing the handle fails
 */
@After
public void tearDown() throws IOException {
  this.trackerCleanup.close();
}
/**
 * A {@link Reiterator} backed by a list and an index into it.
 *
 * <p>A copy shares the backing list and starts at the index this iterator had when {@link
 * #copy()} was called; from then on the two advance independently.
 */
private static class ListReiterator<T> implements Reiterator<T> {
  /** Backing list; shared (not copied) between an iterator and its copies. */
  protected final List<T> entries;

  /** Index of the element the next call to {@link #next()} returns. */
  protected int nextIndex;

  /**
   * @param entries the list to iterate over
   * @param nextIndex the index of the first element to return
   */
  public ListReiterator(List<T> entries, int nextIndex) {
    this.entries = entries;
    this.nextIndex = nextIndex;
  }

  /** Returns a new iterator over the same list, positioned at the current index. */
  @Override
  public Reiterator<T> copy() {
    return new ListReiterator<>(entries, nextIndex);
  }

  @Override
  public boolean hasNext() {
    return nextIndex < entries.size();
  }

  /**
   * Returns the element at the current index and advances by one.
   *
   * @throws NoSuchElementException if the iterator is already past the last element
   */
  @Override
  public T next() {
    // Honor the Iterator contract instead of leaking IndexOutOfBoundsException from List.get().
    if (nextIndex >= entries.size()) {
      throw new NoSuchElementException();
    }
    T res = entries.get(nextIndex);
    nextIndex++;
    return res;
  }

  /** Removal is not supported. */
  @Override
  public void remove() {
    throw new UnsupportedOperationException();
  }
}
/**
 * Makes the tracker enter a fresh test execution state.
 *
 * @param mockOriginalName passed as the second argument of {@code NameContext.create}, alongside
 *     the fixed mock stage, system and user names
 */
private void setCurrentExecutionState(String mockOriginalName) {
  NameContext nameContext =
      NameContext.create(MOCK_STAGE_NAME, mockOriginalName, MOCK_SYSTEM_NAME, MOCK_USER_NAME);
  DataflowExecutionState executionState = new TestDataflowExecutionState(nameContext, "activity");
  tracker.enterState(executionState);
}
/**
 * Builds a {@code ShuffleEntry} from UTF-8 encoded strings.
 *
 * @param key encoded as the entry's key and also used as its position
 * @param value encoded as the entry's value
 */
private static ShuffleEntry shuffleEntry(String key, String value) {
  // The key itself doubles as the position; it is encoded separately for each use.
  ByteArrayShufflePosition position = ByteArrayShufflePosition.of(key.getBytes(Charsets.UTF_8));
  byte[] keyBytes = key.getBytes(Charsets.UTF_8);
  byte[] emptyBytes = new byte[0];
  byte[] valueBytes = value.getBytes(Charsets.UTF_8);
  return new ShuffleEntry(position, keyBytes, emptyBytes, valueBytes);
}
/**
 * Tests that copies of a values iterator yield the same entry as the iterator they were copied
 * from, and that the bytes committed by the iterator land in a counter named after the currently
 * executing step.
 */
@Test
public void testCopyValuesIterator() throws Exception {
  setCurrentExecutionState(MOCK_ORIGINAL_NAME_FOR_EXECUTING_STEP1);
  MockitoAnnotations.initMocks(this);

  PipelineOptions options = PipelineOptionsFactory.create();
  options
      .as(DataflowPipelineDebugOptions.class)
      .setExperiments(Lists.newArrayList(Experiment.IntertransformIO.getName()));
  // NOTE(review): this spy is not referenced again in the test; confirm whether constructing it
  // is needed for its side effects before removing it.
  BatchModeExecutionContext spyExecutionContext =
      Mockito.spy(BatchModeExecutionContext.forTesting(options, "STAGE"));

  // A single key group with three values.
  List<ShuffleEntry> entries = new ArrayList<>();
  entries.add(shuffleEntry("k1", "v11"));
  entries.add(shuffleEntry("k1", "v12"));
  entries.add(shuffleEntry("k1", "v13"));
  when(reader.read(START_POSITION, END_POSITION)).thenReturn(new ListReiterator<>(entries, 0));

  final ShuffleReadCounter shuffleReadCounter =
      new ShuffleReadCounter(ORIGINAL_SHUFFLE_STEP_NAME, true, null);
  iterator =
      new GroupingShuffleEntryIterator(reader, START_POSITION, END_POSITION) {
        @Override
        protected void notifyElementRead(long byteSize) {
          // nothing
        }

        @Override
        protected void commitBytesRead(long bytes) {
          // Forward committed bytes to the counter that is verified below.
          shuffleReadCounter.addBytesRead(bytes);
        }
      };

  assertTrue(iterator.advance());
  KeyGroupedShuffleEntries k1Entries = iterator.getCurrent();
  Reiterator<ShuffleEntry> values1 = k1Entries.values.iterator();
  assertTrue(values1.hasNext());

  // Take both copies before the original is advanced so all three start at the same entry.
  Reiterator<ShuffleEntry> values1Copy1 = values1.copy();
  Reiterator<ShuffleEntry> values1Copy2 = values1.copy();
  ShuffleEntry expectedEntry = values1.next();
  assertFalse(iterator.advance()); // Advance the iterator again to record bytes read.

  // Test that the copy works two ways: 1) if we call hasNext, and 2) if we don't.
  assertTrue(values1Copy1.hasNext());
  assertEquals(expectedEntry, values1Copy1.next());
  assertEquals(expectedEntry, values1Copy2.next());

  // Presumably 3 entries x (2 key bytes + 3 value bytes) - TODO confirm against ShuffleEntry.
  Map<String, Long> expectedReadBytesMap = new HashMap<>();
  expectedReadBytesMap.put(MOCK_ORIGINAL_NAME_FOR_EXECUTING_STEP1, 15L);

  // Verify that each executing step used when reading from the GroupingShuffleReader
  // has a counter with a bytes read value.
  assertEquals(expectedReadBytesMap.size(), (long) shuffleReadCounter.counterSet.size());
  // Typed iterator: avoids the raw Iterator / raw Map.Entry cast and its unchecked warnings.
  Iterator<Map.Entry<String, Long>> it = expectedReadBytesMap.entrySet().iterator();
  while (it.hasNext()) {
    Map.Entry<String, Long> pair = it.next();
    Counter counter =
        shuffleReadCounter.counterSet.getExistingCounter(
            ShuffleReadCounter.generateCounterName(ORIGINAL_SHUFFLE_STEP_NAME, pair.getKey()));
    assertEquals(pair.getValue(), counter.getAggregate());
  }
}
/**
 * A ShuffleEntryReader that asserts that its iterators never go backwards ("reiterate").
 *
 * <p>All iterators handed out by one reader instance, including their copies, share a single
 * low-water mark. Consuming an entry whose index is below that mark fails the test.
 */
private static class ForwardOnlyShuffleEntryReader implements ShuffleEntryReader {
  private final List<ShuffleEntry> entries;

  // Index of the most recently consumed entry; iterators may not consume anything before it.
  private int minNextIndex = 0;

  public ForwardOnlyShuffleEntryReader(List<ShuffleEntry> entries) {
    this.entries = entries;
  }

  /**
   * Returns an asserting iterator that starts at the first entry. The position arguments are
   * ignored.
   */
  @Override
  public Reiterator<ShuffleEntry> read(
      @Nullable ShufflePosition startPosition, @Nullable ShufflePosition endPosition) {
    // Diamond instead of a raw type, so the element type is checked by the compiler.
    return new MinIndexAssertingListReiterator<>(entries, 0);
  }

  @Override
  public void close() {}

  /**
   * A ListReiterator that checks every {@link #next()} call against the enclosing reader's
   * low-water mark. It is deliberately a non-static inner class because it reads and updates
   * {@code minNextIndex} of the enclosing reader.
   */
  private class MinIndexAssertingListReiterator<T> extends ListReiterator<T> {
    public MinIndexAssertingListReiterator(List<T> entries, int nextIndex) {
      super(entries, nextIndex);
    }

    /** Returns a copy that reports to the same enclosing reader. */
    @Override
    public Reiterator<T> copy() {
      return new MinIndexAssertingListReiterator<T>(entries, nextIndex);
    }

    @Override
    public T next() {
      // Re-reading the entry at the mark is allowed (>=); anything earlier is a reiteration.
      assertTrue("Reiteration unexpected.", nextIndex >= minNextIndex);
      minNextIndex = nextIndex;
      return super.next();
    }
  }
}
/**
 * Tests that consuming the value iterators returned by GroupingShuffleEntryIterator never makes
 * the underlying shuffle iterator go backwards, i.e. that fast-forwarding to the next key works
 * without reiteration.
 */
@Test
public void testNoReiteration() throws Exception {
  // Two key groups (k1, k2) with three values each: v11..v13 and v21..v23.
  List<ShuffleEntry> shuffleEntries = new ArrayList<>();
  for (int key = 1; key <= 2; key++) {
    for (int value = 1; value <= 3; value++) {
      shuffleEntries.add(shuffleEntry("k" + key, "v" + key + value));
    }
  }

  ForwardOnlyShuffleEntryReader forwardOnlyReader =
      new ForwardOnlyShuffleEntryReader(shuffleEntries);
  iterator =
      new GroupingShuffleEntryIterator(forwardOnlyReader, START_POSITION, END_POSITION) {
        @Override
        protected void notifyElementRead(long byteSize) {
          // Not needed for this test.
        }

        @Override
        protected void commitBytesRead(long bytes) {
          // Not needed for this test.
        }
      };

  int keyCount = 0;
  int valueCount = 0;
  while (iterator.advance()) {
    keyCount++;
    // Drain every value of the current key before moving on to the next key.
    for (Reiterator<ShuffleEntry> values = iterator.getCurrent().values.iterator();
        values.hasNext();
        values.next()) {
      valueCount++;
    }
  }

  assertEquals(2, keyCount);
  assertEquals(6, valueCount);
  // Reaching this point also means MinIndexAssertingListReiterator.next() never raised an
  // AssertionError, i.e. no reiteration happened.
}
}