blob: 041fd038543a65ae80a4b0d0cecbcfc9a31c7377 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.runtime.library.common.shuffle.impl;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import com.google.common.annotations.VisibleForTesting;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.security.token.Token;
import org.apache.tez.common.TezExecutors;
import org.apache.tez.common.TezRuntimeFrameworkConfigs;
import org.apache.tez.common.TezSharedExecutor;
import org.apache.tez.common.counters.TezCounters;
import org.apache.tez.common.security.JobTokenIdentifier;
import org.apache.tez.common.security.JobTokenSecretManager;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.ExecutionContext;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.api.events.DataMovementEvent;
import org.apache.tez.runtime.api.events.InputReadErrorEvent;
import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
import org.apache.tez.runtime.library.common.shuffle.FetchedInput;
import org.apache.tez.runtime.library.common.shuffle.FetchedInputAllocator;
import org.apache.tez.runtime.library.common.shuffle.Fetcher;
import org.apache.tez.runtime.library.common.shuffle.InputAttemptFetchFailure;
import org.apache.tez.runtime.library.common.shuffle.FetchResult;
import org.apache.tez.runtime.library.common.shuffle.InputHost;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.DataMovementEventPayloadProto;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
public class TestShuffleManager {
/** Host name returned by the mocked ExecutionContext of the fetching task. */
private static final String FETCHER_HOST = "localhost";
/** Port written into the service provider metadata and into every DataMovementEvent payload. */
private static final int PORT = 8080;
/** Path component set on every DataMovementEvent payload built by this class. */
private static final String PATH_COMPONENT = "attempttmp";
/** Configuration used by all helpers of this class; helpers and tests modify it in place. */
private final Configuration conf = new Configuration();
/** Executor factory created in setup() and shut down in cleanup(). */
private TezExecutors sharedExecutor;
/**
 * Creates the shared executor from which the mocked input context serves
 * executor service requests.
 */
@Before
public void setup() {
  this.sharedExecutor = new TezSharedExecutor(this.conf);
}
/**
 * Shuts down the shared executor created in {@link #setup()}.
 */
@After
public void cleanup() {
  // JUnit runs @After methods even when a @Before method threw, so the
  // executor may never have been created. Skip the shutdown in that case
  // instead of adding a NullPointerException on top of the real failure.
  if (sharedExecutor != null) {
    sharedExecutor.shutdownNow();
  }
}
/**
 * A single reducer fetches several partitions from every mapper.
 * For each mapper, the DataMovementEvents are delivered in two batches: the
 * first {@code firstPart} partitions, then, after a short pause, the
 * remaining ones. The same is repeated for the next mapper.
 * Verifies that ShuffleManager completes every input and shuts down its
 * fetcher executor.
 */
@Test(timeout = 50000)
public void testMultiplePartitions() throws Exception {
  final int numOfMappers = 3;
  final int numOfPartitions = 5;
  final int firstPart = 2;
  final int expectedInputs = numOfMappers * numOfPartitions;

  InputContext inputContext = createInputContext();
  ShuffleManagerForTest shuffleManager =
      createShuffleManager(inputContext, expectedInputs);
  FetchedInputAllocator inputAllocator = mock(FetchedInputAllocator.class);
  ShuffleInputEventHandlerImpl handler = new ShuffleInputEventHandlerImpl(
      inputContext, shuffleManager, inputAllocator, null, false, 0, false);
  shuffleManager.run();

  // The same list instance is reused (and cleared) for every batch.
  List<Event> eventList = new LinkedList<Event>();
  // Physical input index within the reduce task; keeps growing across mappers.
  int targetIndex = 0;
  for (int mapper = 0; mapper < numOfMappers; mapper++) {
    String mapperHost = "host" + mapper;
    // Physical output index within the map task; restarts for every mapper.
    int srcIndex = 20;

    // First batch: the leading firstPart partitions of this mapper.
    eventList.clear();
    while (eventList.size() < firstPart) {
      eventList.add(createDataMovementEvent(mapperHost, srcIndex++, targetIndex++));
    }
    handler.handleEvents(eventList);

    Thread.sleep(500);

    // Second batch: whatever partitions of this mapper are left.
    eventList.clear();
    while (eventList.size() < numOfPartitions - firstPart) {
      eventList.add(createDataMovementEvent(mapperHost, srcIndex++, targetIndex++));
    }
    handler.handleEvents(eventList);
  }

  // Poll (at most 100 times, 100 ms apart) until all inputs are completed
  // and the fetcher executor has been shut down.
  for (int attempt = 0; attempt < 100; attempt++) {
    boolean finished = shuffleManager.isFetcherExecutorShutdown()
        && expectedInputs == shuffleManager.getNumOfCompletedInputs();
    if (finished) {
      break;
    }
    Thread.sleep(100);
  }

  assertTrue(shuffleManager.isFetcherExecutorShutdown());
  assertEquals(expectedInputs, shuffleManager.getNumOfCompletedInputs());
}
/**
 * Builds a mocked {@link InputContext} for the tests in this class.
 * The mock returns fresh counters, the source vertex name "sourceVertex",
 * an execution context whose host name is {@code FETCHER_HOST}, and service
 * provider metadata that contains {@code PORT} written as an int. Requests
 * for a framework executor service are forwarded to {@code sharedExecutor}.
 *
 * @return the stubbed input context mock
 * @throws IOException if the port cannot be written to the buffer
 */
private InputContext createInputContext() throws IOException {
  // Metadata buffer limited to the bytes actually written for the port.
  DataOutputBuffer portBuffer = new DataOutputBuffer();
  portBuffer.writeInt(PORT);
  final ByteBuffer shuffleMetaData =
      ByteBuffer.wrap(portBuffer.getData(), 0, portBuffer.getLength());
  portBuffer.close();

  String auxiliaryServiceId = conf.get(
      TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
      TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT);

  ExecutionContext executionContext = mock(ExecutionContext.class);
  doReturn(FETCHER_HOST).when(executionContext).getHostName();

  InputContext inputContext = mock(InputContext.class);
  doReturn(new TezCounters()).when(inputContext).getCounters();
  doReturn("sourceVertex").when(inputContext).getSourceVertexName();
  doReturn(shuffleMetaData).when(inputContext)
      .getServiceProviderMetaData(auxiliaryServiceId);
  doReturn(executionContext).when(inputContext).getExecutionContext();

  // Serve executor requests from the executor owned by this test class.
  Answer<ExecutorService> delegateToSharedExecutor = new Answer<ExecutorService>() {
    @Override
    public ExecutorService answer(InvocationOnMock invocation) throws Throwable {
      Integer intArgument = invocation.getArgumentAt(0, Integer.class);
      String stringArgument = invocation.getArgumentAt(1, String.class);
      return sharedExecutor.createExecutorService(intArgument, stringArgument);
    }
  };
  when(inputContext.createTezFrameworkExecutorService(anyInt(), anyString()))
      .thenAnswer(delegateToSharedExecutor);

  return inputContext;
}
/**
 * Checks that creating a ShuffleManager requests a framework executor
 * service from the input context only after
 * TEZ_RUNTIME_SHUFFLE_FETCHER_USE_SHARED_POOL has been set to true.
 */
@Test(timeout=5000)
public void testUseSharedExecutor() throws Exception {
  // Flag not set on conf: the context must not be asked for an executor.
  InputContext defaultContext = createInputContext();
  createShuffleManager(defaultContext, 2);
  verify(defaultContext, times(0))
      .createTezFrameworkExecutorService(anyInt(), anyString());

  // Flag enabled: exactly one executor request is expected on a fresh mock.
  InputContext sharedPoolContext = createInputContext();
  conf.setBoolean(
      TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_FETCHER_USE_SHARED_POOL, true);
  createShuffleManager(sharedPoolContext, 2);
  verify(sharedPoolContext, times(1))
      .createTezFrameworkExecutorService(anyInt(), anyString());
}
/**
 * Runs a ShuffleManager that never receives any input event for about four
 * seconds and verifies that the input context was notified of progress at
 * least three times in that period.
 */
@Test (timeout = 20000)
public void testProgressWithEmptyPendingHosts() throws Exception {
  InputContext inputContext = createInputContext();
  final ShuffleManager shuffleManager = spy(createShuffleManager(inputContext, 1));
  Thread schedulerGetHostThread = new Thread(new Runnable() {
    @Override
    public void run() {
      try {
        shuffleManager.run();
      } catch (Exception e) {
        e.printStackTrace();
      }
    }
  });
  schedulerGetHostThread.start();
  // Thread.sleep is static and always pauses the calling thread; invoke it on
  // the class rather than through Thread.currentThread(), which wrongly
  // suggests that the target thread can be chosen.
  Thread.sleep(1000 * 3 + 1000);
  schedulerGetHostThread.interrupt();
  verify(inputContext, atLeast(3)).notifyProgress();
}
/**
 * Verifies how fetch failures are reported through
 * {@code InputContext.sendEvents}:
 * <ol>
 *   <li>one second after the first failure, exactly one batch has been sent,
 *       holding a single InputReadErrorEvent with one failure;</li>
 *   <li>one second after two further failures for the same input, no
 *       additional batch has been sent yet;</li>
 *   <li>after waiting longer than the batch wait (set to 5000 ms by
 *       {@code createShuffleManager}), a second batch has been sent, holding
 *       a single InputReadErrorEvent that carries both failures.</li>
 * </ol>
 */
@Test (timeout = 200000)
public void testFetchFailed() throws Exception {
  InputContext inputContext = createInputContext();
  final ShuffleManager shuffleManager = spy(createShuffleManager(inputContext, 1));
  Thread schedulerGetHostThread = new Thread(new Runnable() {
    @Override
    public void run() {
      try {
        shuffleManager.run();
      } catch (Exception e) {
        e.printStackTrace();
      }
    }
  });
  InputAttemptFetchFailure inputAttemptFetchFailure =
      new InputAttemptFetchFailure(new InputAttemptIdentifier(1, 1));

  schedulerGetHostThread.start();
  // try/finally: the background thread must be interrupted even when one of
  // the verifications below fails, otherwise it outlives the test.
  try {
    Thread.sleep(1000);
    shuffleManager.fetchFailed("host1", inputAttemptFetchFailure, false);
    Thread.sleep(1000);

    // The first failure goes out on its own.
    // assertEquals takes (message, expected, actual).
    ArgumentCaptor<List> captor = ArgumentCaptor.forClass(List.class);
    verify(inputContext, times(1))
        .sendEvents(captor.capture());
    Assert.assertEquals("Size was: " + captor.getAllValues().size(),
        1, captor.getAllValues().size());
    List<Event> capturedList = captor.getAllValues().get(0);
    Assert.assertEquals("Size was: " + capturedList.size(),
        1, capturedList.size());
    InputReadErrorEvent inputEvent = (InputReadErrorEvent)capturedList.get(0);
    Assert.assertEquals("Number of failures was: " + inputEvent.getNumFailures(),
        1, inputEvent.getNumFailures());

    // Two more failures: nothing new may be sent within the next second.
    shuffleManager.fetchFailed("host1", inputAttemptFetchFailure, false);
    shuffleManager.fetchFailed("host1", inputAttemptFetchFailure, false);
    Thread.sleep(1000);
    verify(inputContext, times(1)).sendEvents(any());

    // Wait more than five seconds for the batch to go out
    Thread.sleep(5000);
    captor = ArgumentCaptor.forClass(List.class);
    verify(inputContext, times(2))
        .sendEvents(captor.capture());
    Assert.assertEquals("Size was: " + captor.getAllValues().size(),
        2, captor.getAllValues().size());
    capturedList = captor.getAllValues().get(1);
    Assert.assertEquals("Size was: " + capturedList.size(),
        1, capturedList.size());
    inputEvent = (InputReadErrorEvent)capturedList.get(0);
    Assert.assertEquals("Number of failures was: " + inputEvent.getNumFailures(),
        2, inputEvent.getNumFailures());
  } finally {
    schedulerGetHostThread.interrupt();
  }
}
/**
 * Creates a {@link ShuffleManagerForTest} bound to the given mocked input
 * context.
 *
 * <p>Side effects: stubs {@code getWorkDirs()} and
 * {@code getServiceConsumerMetaData(...)} on the context, and sets
 * {@code LOCAL_DIRS} and {@code TEZ_RUNTIME_SHUFFLE_BATCH_WAIT} (5000) on the
 * configuration of this test class.
 *
 * @param inputContext mocked input context that is stubbed and handed to the
 *                     manager
 * @param expectedNumOfPhysicalInputs number of inputs passed to the manager's
 *                                    constructor
 * @return the newly created manager
 * @throws IOException if the job token cannot be serialized or the manager's
 *                     constructor fails
 */
private ShuffleManagerForTest createShuffleManager(
    InputContext inputContext, int expectedNumOfPhysicalInputs)
    throws IOException {
  Path outDirBase = new Path(".", "outDir");
  String[] outDirs = new String[] { outDirBase.toString() };
  doReturn(outDirs).when(inputContext).getWorkDirs();
  conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS,
      inputContext.getWorkDirs());
  // 5 seconds
  conf.setInt(TezRuntimeConfiguration.TEZ_RUNTIME_SHUFFLE_BATCH_WAIT, 5000);

  DataOutputBuffer out = new DataOutputBuffer();
  Token<JobTokenIdentifier> token = new Token<JobTokenIdentifier>(new JobTokenIdentifier(),
      new JobTokenSecretManager(null));
  token.write(out);
  // getData() returns the whole backing array, which is only valid up to
  // getLength(); bound the buffer to the written bytes (as createInputContext
  // does) so no unused trailing bytes are exposed as token data.
  ByteBuffer tokenBytes = ByteBuffer.wrap(out.getData(), 0, out.getLength());
  doReturn(tokenBytes).when(inputContext).
      getServiceConsumerMetaData(
          conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
              TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT));

  FetchedInputAllocator inputAllocator = mock(FetchedInputAllocator.class);
  return new ShuffleManagerForTest(inputContext, conf,
      expectedNumOfPhysicalInputs, 1024, false, -1, null, inputAllocator);
}
/**
 * Builds a DataMovementEvent whose payload names the given host together
 * with {@code PORT} and {@code PATH_COMPONENT}.
 *
 * @param host host name stored in the payload
 * @param srcIndex source index passed to {@code DataMovementEvent.create}
 * @param targetIndex target index passed to {@code DataMovementEvent.create}
 * @return the new event
 */
private Event createDataMovementEvent(String host, int srcIndex, int targetIndex) {
  DataMovementEventPayloadProto payload = DataMovementEventPayloadProto.newBuilder()
      .setHost(host)
      .setPort(PORT)
      .setPathComponent(PATH_COMPONENT)
      .build();
  // Third argument is always 0 here (presumably the event version - confirm
  // against DataMovementEvent.create).
  return DataMovementEvent.create(srcIndex, targetIndex, 0,
      payload.toByteString().asReadOnlyByteBuffer());
}
/**
 * ShuffleManager whose fetchers are spies with {@code callInternal()}
 * stubbed out: instead of running the real call, every source attempt of the
 * fetcher is immediately reported to this manager as fetched successfully.
 * Also exposes a little internal state for assertions.
 */
private static class ShuffleManagerForTest extends ShuffleManager {
  public ShuffleManagerForTest(InputContext inputContext, Configuration conf,
      int numInputs, int bufferSize, boolean ifileReadAheadEnabled,
      int ifileReadAheadLength, CompressionCodec codec,
      FetchedInputAllocator inputAllocator) throws IOException {
    super(inputContext, conf, numInputs, bufferSize, ifileReadAheadEnabled,
        ifileReadAheadLength, codec, inputAllocator);
  }

  /**
   * Wraps the fetcher built by the superclass in a spy whose
   * {@code callInternal()} reports success for all of its source attempts
   * and returns a mocked FetchResult.
   *
   * @throws IllegalStateException if the stubbing itself fails
   */
  @Override
  Fetcher constructFetcherForHost(InputHost inputHost, Configuration conf) {
    final Fetcher fetcher = spy(super.constructFetcherForHost(inputHost,
        conf));
    final FetchResult mockFetcherResult = mock(FetchResult.class);
    try {
      doAnswer(new Answer<FetchResult>() {
        @Override
        public FetchResult answer(InvocationOnMock invocation) throws Throwable {
          for (InputAttemptIdentifier input : fetcher.getSrcAttempts()) {
            ShuffleManagerForTest.this.fetchSucceeded(
                fetcher.getHost(), input, new TestFetchedInput(input), 0, 0,
                0);
          }
          return mockFetcherResult;
        }
      }).when(fetcher).callInternal();
    } catch (Exception e) {
      // Fail loudly: swallowing this would hand back a fetcher without the
      // stub, and the test would then silently run the real callInternal().
      throw new IllegalStateException("Failed to stub Fetcher.callInternal()", e);
    }
    return fetcher;
  }

  /** Returns the cardinality of {@code completedInputSet}. */
  public int getNumOfCompletedInputs() {
    return completedInputSet.cardinality();
  }

  /** Returns whether {@code fetcherExecutor} reports that it has been shut down. */
  boolean isFetcherExecutorShutdown() {
    return fetcherExecutor.isShutdown();
  }
}
/**
 * Data-less FetchedInput stand-in. ShuffleManagerForTest passes instances of
 * it to {@code fetchSucceeded} so that an input can be marked as fetched
 * without any bytes behind it.
 */
@VisibleForTesting
static class TestFetchedInput extends FetchedInput {

  public TestFetchedInput(InputAttemptIdentifier inputAttemptIdentifier) {
    super(inputAttemptIdentifier, null);
  }

  // Descriptive accessors: a MEMORY input that reports a size of -1.

  @Override
  public Type getType() {
    return Type.MEMORY;
  }

  @Override
  public long getSize() {
    return -1;
  }

  // No streams are available for this fake input.

  @Override
  public InputStream getInputStream() throws IOException {
    return null;
  }

  @Override
  public OutputStream getOutputStream() throws IOException {
    return null;
  }

  // Lifecycle callbacks are deliberately empty: there is nothing to commit,
  // abort or release.

  @Override
  public void commit() throws IOException {
  }

  @Override
  public void abort() throws IOException {
  }

  @Override
  public void free() {
  }
}
}