| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.runners.dataflow.worker.util.common.worker; |
| |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertFalse; |
| import static org.junit.Assert.assertTrue; |
| import static org.mockito.Mockito.when; |
| |
| import java.io.Closeable; |
| import java.io.IOException; |
| import java.nio.charset.StandardCharsets; |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.Iterator; |
| import java.util.List; |
| import java.util.Map; |
| import javax.annotation.Nullable; |
| import org.apache.beam.runners.core.metrics.ExecutionStateSampler; |
| import org.apache.beam.runners.core.metrics.ExecutionStateTracker; |
| import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions; |
| import org.apache.beam.runners.dataflow.worker.BatchModeExecutionContext; |
| import org.apache.beam.runners.dataflow.worker.DataflowOperationContext.DataflowExecutionState; |
| import org.apache.beam.runners.dataflow.worker.ExperimentContext.Experiment; |
| import org.apache.beam.runners.dataflow.worker.TestOperationContext.TestDataflowExecutionState; |
| import org.apache.beam.runners.dataflow.worker.counters.Counter; |
| import org.apache.beam.runners.dataflow.worker.counters.NameContext; |
| import org.apache.beam.sdk.options.PipelineOptions; |
| import org.apache.beam.sdk.options.PipelineOptionsFactory; |
| import org.apache.beam.sdk.util.common.Reiterator; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; |
| import org.junit.After; |
| import org.junit.Before; |
| import org.junit.Test; |
| import org.junit.runner.RunWith; |
| import org.junit.runners.JUnit4; |
| import org.mockito.Mock; |
| import org.mockito.Mockito; |
| import org.mockito.MockitoAnnotations; |
| |
| /** Tests for {@link GroupingShuffleEntryIterator}. */ |
| @RunWith(JUnit4.class) |
| public class GroupingShuffleEntryIteratorTest { |
| private static final ByteArrayShufflePosition START_POSITION = |
| ByteArrayShufflePosition.of("aaa".getBytes(StandardCharsets.UTF_8)); |
| private static final ByteArrayShufflePosition END_POSITION = |
| ByteArrayShufflePosition.of("zzz".getBytes(StandardCharsets.UTF_8)); |
| |
| private static final String MOCK_STAGE_NAME = "mockStageName"; |
| private static final String MOCK_ORIGINAL_NAME_FOR_EXECUTING_STEP1 = "mockOriginalName1"; |
| private static final String MOCK_SYSTEM_NAME = "mockSystemName"; |
| private static final String MOCK_USER_NAME = "mockUserName"; |
| private static final String ORIGINAL_SHUFFLE_STEP_NAME = "originalName"; |
| |
| @Mock private ShuffleEntryReader reader; |
| private GroupingShuffleEntryIterator iterator; |
| |
| private final ExecutionStateSampler sampler = ExecutionStateSampler.newForTest(); |
| private final ExecutionStateTracker tracker = new ExecutionStateTracker(sampler); |
| private Closeable trackerCleanup; |
| |
| @Before |
| public void setUp() { |
| trackerCleanup = tracker.activate(); |
| } |
| |
| @After |
| public void tearDown() throws IOException { |
| trackerCleanup.close(); |
| } |
| |
| private static class ListReiterator<T> implements Reiterator<T> { |
| protected final List<T> entries; |
| protected int nextIndex; |
| |
| public ListReiterator(List<T> entries, int nextIndex) { |
| this.entries = entries; |
| this.nextIndex = nextIndex; |
| } |
| |
| @Override |
| public Reiterator<T> copy() { |
| return new ListReiterator<T>(entries, nextIndex); |
| } |
| |
| @Override |
| public boolean hasNext() { |
| return nextIndex < entries.size(); |
| } |
| |
| @Override |
| public T next() { |
| T res = entries.get(nextIndex); |
| nextIndex++; |
| return res; |
| } |
| |
| @Override |
| public void remove() { |
| throw new UnsupportedOperationException(); |
| } |
| } |
| |
| private void setCurrentExecutionState(String mockOriginalName) { |
| DataflowExecutionState state = |
| new TestDataflowExecutionState( |
| NameContext.create(MOCK_STAGE_NAME, mockOriginalName, MOCK_SYSTEM_NAME, MOCK_USER_NAME), |
| "activity"); |
| tracker.enterState(state); |
| } |
| |
| private static ShuffleEntry shuffleEntry(String key, String value) { |
| return new ShuffleEntry( |
| /* use key itself as position */ |
| ByteArrayShufflePosition.of(key.getBytes(Charsets.UTF_8)), |
| key.getBytes(Charsets.UTF_8), |
| new byte[0], |
| value.getBytes(Charsets.UTF_8)); |
| } |
| |
| @Test |
| public void testCopyValuesIterator() throws Exception { |
| setCurrentExecutionState(MOCK_ORIGINAL_NAME_FOR_EXECUTING_STEP1); |
| |
| MockitoAnnotations.initMocks(this); |
| PipelineOptions options = PipelineOptionsFactory.create(); |
| options |
| .as(DataflowPipelineDebugOptions.class) |
| .setExperiments(Lists.newArrayList(Experiment.IntertransformIO.getName())); |
| BatchModeExecutionContext spyExecutionContext = |
| Mockito.spy(BatchModeExecutionContext.forTesting(options, "STAGE")); |
| |
| ArrayList<ShuffleEntry> entries = new ArrayList<>(); |
| entries.add(shuffleEntry("k1", "v11")); |
| entries.add(shuffleEntry("k1", "v12")); |
| entries.add(shuffleEntry("k1", "v13")); |
| when(reader.read(START_POSITION, END_POSITION)).thenReturn(new ListReiterator<>(entries, 0)); |
| |
| final ShuffleReadCounter shuffleReadCounter = |
| new ShuffleReadCounter(ORIGINAL_SHUFFLE_STEP_NAME, true, null); |
| iterator = |
| new GroupingShuffleEntryIterator(reader, START_POSITION, END_POSITION) { |
| @Override |
| protected void notifyElementRead(long byteSize) { |
| // nothing |
| } |
| |
| @Override |
| protected void commitBytesRead(long bytes) { |
| shuffleReadCounter.addBytesRead(bytes); |
| } |
| }; |
| assertTrue(iterator.advance()); |
| KeyGroupedShuffleEntries k1Entries = iterator.getCurrent(); |
| Reiterator<ShuffleEntry> values1 = k1Entries.values.iterator(); |
| assertTrue(values1.hasNext()); |
| Reiterator<ShuffleEntry> values1Copy1 = values1.copy(); |
| Reiterator<ShuffleEntry> values1Copy2 = values1.copy(); |
| ShuffleEntry expectedEntry = values1.next(); |
| assertFalse(iterator.advance()); // Advance the iterator again to record bytes read. |
| |
| // Test that the copy works two ways: 1) if we call hasNext, and 2) if we don't. |
| assertTrue(values1Copy1.hasNext()); |
| assertEquals(expectedEntry, values1Copy1.next()); |
| |
| assertEquals(expectedEntry, values1Copy2.next()); |
| |
| Map<String, Long> expectedReadBytesMap = new HashMap<>(); |
| expectedReadBytesMap.put(MOCK_ORIGINAL_NAME_FOR_EXECUTING_STEP1, 15L); |
| |
| // Verify that each executing step used when reading from the GroupingShuffleReader |
| // has a counter with a bytes read value. |
| assertEquals(expectedReadBytesMap.size(), (long) shuffleReadCounter.counterSet.size()); |
| Iterator it = expectedReadBytesMap.entrySet().iterator(); |
| while (it.hasNext()) { |
| Map.Entry<String, Long> pair = (Map.Entry) it.next(); |
| Counter counter = |
| shuffleReadCounter.counterSet.getExistingCounter( |
| ShuffleReadCounter.generateCounterName(ORIGINAL_SHUFFLE_STEP_NAME, pair.getKey())); |
| assertEquals(pair.getValue(), counter.getAggregate()); |
| } |
| } |
| |
| /** A ShuffleEntryReader that asserts that its iterators never go backwards ("reiterate"). */ |
| private static class ForwardOnlyShuffleEntryReader implements ShuffleEntryReader { |
| private final List<ShuffleEntry> entries; |
| private int minNextIndex = 0; // The smallest index that iterators are allowed to advance. |
| |
| public ForwardOnlyShuffleEntryReader(List<ShuffleEntry> entries) { |
| this.entries = entries; |
| } |
| |
| @Override |
| public Reiterator<ShuffleEntry> read( |
| @Nullable ShufflePosition startPosition, @Nullable ShufflePosition endPosition) { |
| return new MinIndexAssertingListReiterator(entries, 0); |
| } |
| |
| @Override |
| public void close() {} |
| |
| private class MinIndexAssertingListReiterator<T> extends ListReiterator<T> { |
| public MinIndexAssertingListReiterator(List<T> entries, int nextIndex) { |
| super(entries, nextIndex); |
| } |
| |
| @Override |
| public Reiterator<T> copy() { |
| return new MinIndexAssertingListReiterator<T>(entries, nextIndex); |
| } |
| |
| @Override |
| public T next() { |
| assertTrue("Reiteration unexpected.", nextIndex >= minNextIndex); |
| minNextIndex = nextIndex; |
| return super.next(); |
| } |
| } |
| } |
| |
| /** |
| * Tests that GroupingShuffleEntryIterator does not reiterate the underlying shuffle iterator when |
| * the returned value iterators are iterated over (i.e., that fast-forwarding works properly). |
| */ |
| @Test |
| public void testNoReiteration() throws Exception { |
| ArrayList<ShuffleEntry> entries = new ArrayList<>(); |
| entries.add(shuffleEntry("k1", "v11")); |
| entries.add(shuffleEntry("k1", "v12")); |
| entries.add(shuffleEntry("k1", "v13")); |
| entries.add(shuffleEntry("k2", "v21")); |
| entries.add(shuffleEntry("k2", "v22")); |
| entries.add(shuffleEntry("k2", "v23")); |
| ForwardOnlyShuffleEntryReader reader = new ForwardOnlyShuffleEntryReader(entries); |
| |
| iterator = |
| new GroupingShuffleEntryIterator(reader, START_POSITION, END_POSITION) { |
| @Override |
| protected void notifyElementRead(long byteSize) { |
| // nothing |
| } |
| |
| @Override |
| protected void commitBytesRead(long bytes) { |
| // nothing |
| } |
| }; |
| int totalKeys = 0; |
| int totalValues = 0; |
| while (iterator.advance()) { |
| ++totalKeys; |
| Reiterator<ShuffleEntry> values = iterator.getCurrent().values.iterator(); |
| while (values.hasNext()) { |
| values.next(); |
| ++totalValues; |
| } |
| } |
| assertEquals(2, totalKeys); |
| assertEquals(6, totalValues); |
| |
| // We expect that AssertionException in MinIndexAssertingListReiterator.next() is not thrown. |
| } |
| } |