blob: 2cabb276aac2222de5d2f93c57342dddbbbc2a60 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.dag.app.rm;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse;
import org.apache.hadoop.yarn.api.records.ApplicationAccessType;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.ContainerStatus;
import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.NodeReport;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.client.api.AMRMClient;
import org.apache.hadoop.yarn.client.api.impl.AMRMClientImpl;
import org.apache.hadoop.yarn.event.Event;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.tez.common.ContainerSignatureMatcher;
import org.apache.tez.common.TezUtils;
import org.apache.tez.dag.api.NamedEntityDescriptor;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.app.AppContext;
import org.apache.tez.dag.app.ServicePluginLifecycleAbstractService;
import org.apache.tez.dag.app.rm.YarnTaskSchedulerService.CookieContainerRequest;
import org.apache.tez.hadoop.shim.HadoopShimsLoader;
import org.apache.tez.serviceplugins.api.DagInfo;
import org.apache.tez.serviceplugins.api.ServicePluginError;
import org.apache.tez.serviceplugins.api.TaskScheduler;
import org.apache.tez.serviceplugins.api.TaskSchedulerContext;
class TestTaskSchedulerHelpers {
/**
 * An {@link AMRMClientImpl} whose service start/stop hooks are no-ops. Tests use it for the
 * parent's request bookkeeping (request matching) without running whatever the parent's
 * lifecycle hooks would normally do.
 */
static class AMRMClientForTest extends AMRMClientImpl<CookieContainerRequest> {

  @Override
  protected void serviceStart() {
    // Deliberately empty: the parent's start-up logic is skipped.
  }

  @Override
  protected void serviceStop() {
    // Deliberately empty: nothing was started, so there is nothing to stop.
  }
}
/**
 * A {@link TezAMRMClientAsync} for tests that exposes the wrapped client's request matching
 * without any real RM interaction: registration returns a Mockito mock, while unregistration
 * and the service start/stop hooks do nothing.
 */
static class AMRMClientAsyncForTest extends
    TezAMRMClientAsync<CookieContainerRequest> {

  /**
   * @param client the AMRM client to wrap
   * @param intervalMs interval in milliseconds, passed unchanged to the parent constructor
   */
  public AMRMClientAsyncForTest(
      AMRMClient<CookieContainerRequest> client,
      int intervalMs) {
    // No CallbackHandler is installed: tests invoke the callbacks themselves.
    super(client, intervalMs, null);
  }

  /**
   * Returns a mocked registration response whose maximum resource capability and application
   * ACLs are themselves mocks. The arguments are ignored.
   */
  @SuppressWarnings("unchecked")
  @Override
  public RegisterApplicationMasterResponse registerApplicationMaster(
      String appHostName, int appHostPort, String appTrackingUrl) {
    Resource maxCapability = mock(Resource.class);
    Map<ApplicationAccessType, String> acls = mock(Map.class);
    RegisterApplicationMasterResponse response =
        mock(RegisterApplicationMasterResponse.class);
    when(response.getMaximumResourceCapability()).thenReturn(maxCapability);
    when(response.getApplicationACLs()).thenReturn(acls);
    return response;
  }

  @Override
  public void unregisterApplicationMaster(FinalApplicationStatus appStatus,
      String appMessage, String appTrackingUrl) {
    // Deliberately empty: there is no RM to unregister from.
  }

  @Override
  protected void serviceStart() {
    // Deliberately empty: the parent's start-up logic is skipped.
  }

  @Override
  protected void serviceStop() {
    // Deliberately empty: nothing was started, so there is nothing to stop.
  }
}
/**
 * A {@link TaskSchedulerManager} for tests that installs a single spied
 * {@link TaskSchedulerWithDrainableContext} in scheduler slot 0. Start and stop are overridden so
 * the scheduler is driven inline, without the extra event-handling thread.
 */
static class TaskSchedulerManagerForTest extends
    TaskSchedulerManager {

  private final TezAMRMClientAsync<CookieContainerRequest> amrmClientAsync;
  // Kept alongside the other constructor arguments; the same matcher is handed to the parent.
  private final ContainerSignatureMatcher containerSignatureMatcher;
  private final UserPayload defaultPayload;

  /**
   * @param appContext application context; its AM configuration is used to load the Hadoop shim
   * @param eventHandler event handler passed to the parent manager
   * @param amrmClientAsync client given to the scheduler built in instantiateSchedulers
   * @param containerSignatureMatcher matcher passed to the parent manager
   * @param defaultPayload payload given to the scheduler's context
   */
  @SuppressWarnings("rawtypes")
  public TaskSchedulerManagerForTest(AppContext appContext,
      EventHandler eventHandler,
      TezAMRMClientAsync<CookieContainerRequest> amrmClientAsync,
      ContainerSignatureMatcher containerSignatureMatcher,
      UserPayload defaultPayload) {
    super(appContext, null, eventHandler, containerSignatureMatcher, null,
        Lists.newArrayList(new NamedEntityDescriptor("FakeScheduler", null)),
        false, new HadoopShimsLoader(appContext.getAMConf()).getHadoopShim());
    this.defaultPayload = defaultPayload;
    this.containerSignatureMatcher = containerSignatureMatcher;
    this.amrmClientAsync = amrmClientAsync;
  }

  /**
   * Builds the callback chain (real context, then counting-executor wrapper, then drainable)
   * and installs a spied drainable scheduler plus its lifecycle service in slot 0.
   */
  @SuppressWarnings("unchecked")
  @Override
  public void instantiateSchedulers(String host, int port, String trackingUrl,
      AppContext appContext) {
    TaskSchedulerContext realContext = new TaskSchedulerContextImpl(this, appContext, 0,
        trackingUrl, 1000, host, port, defaultPayload);
    CountingExecutorService countingExecutor = new CountingExecutorService(appCallbackExecutor);
    TaskSchedulerContextImplWrapper wrappedContext =
        new TaskSchedulerContextImplWrapper(realContext, countingExecutor);
    TaskSchedulerContextDrainable drainableContext =
        new TaskSchedulerContextDrainable(wrappedContext);
    TaskSchedulerWithDrainableContext spiedScheduler =
        spy(new TaskSchedulerWithDrainableContext(drainableContext, amrmClientAsync));

    taskSchedulers[0] = new TaskSchedulerWrapper(spiedScheduler);
    taskSchedulerServiceWrappers[0] =
        new ServicePluginLifecycleAbstractService(taskSchedulers[0].getTaskScheduler());
  }

  /** Returns the spied scheduler installed in slot 0. */
  public TaskScheduler getSpyTaskScheduler() {
    return taskSchedulers[0].getTaskScheduler();
  }

  @Override
  public void serviceStart() {
    instantiateSchedulers("host", 0, "", appContext);
    // init() must precede start() so the reuse configuration is picked up.
    taskSchedulerServiceWrappers[0].init(getConfig());
    taskSchedulerServiceWrappers[0].start();
  }

  @Override
  public void serviceStop() {
    // Deliberately empty: the parent's stop logic is skipped.
  }
}
@SuppressWarnings("rawtypes")
static class CapturingEventHandler implements EventHandler {
private Queue<Event> events = new ConcurrentLinkedQueue<Event>();
public void handle(Event event) {
events.add(event);
}
public void reset() {
events.clear();
}
public void verifyNoInvocations(Class<? extends Event> eventClass) {
for (Event e : events) {
assertFalse(e.getClass().getName().equals(eventClass.getName()));
}
}
public Event verifyInvocation(Class<? extends Event> eventClass) {
for (Event e : events) {
if (e.getClass().getName().equals(eventClass.getName())) {
return e;
}
}
fail("Expected Event: " + eventClass.getName() + " not sent");
return null;
}
}
/**
 * A {@link YarnTaskSchedulerService} built on a {@link TaskSchedulerContextDrainable}, so tests
 * can wait for the callbacks it issues. The constructor sets the parent's
 * {@code shouldUnregister} flag to {@code true}.
 */
static class TaskSchedulerWithDrainableContext extends YarnTaskSchedulerService {

  public TaskSchedulerWithDrainableContext(
      TaskSchedulerContextDrainable appClient,
      TezAMRMClientAsync<CookieContainerRequest> client) {
    super(appClient, client);
    shouldUnregister.set(true);
  }

  /** Returns this scheduler's context, cast back to the drainable type given at construction. */
  public TaskSchedulerContextDrainable getDrainableAppCallback() {
    TaskSchedulerContext context = getContext();
    return (TaskSchedulerContextDrainable) context;
  }
}
/**
 * A {@link TaskSchedulerContext} decorator for scheduler tests. It counts each callback that the
 * wrapped context dispatches through its {@link CountingExecutorService}, so {@link #drain()}
 * can block until all of them have completed. Getters that neither obtain locks nor go through
 * the executor are forwarded without being counted.
 *
 * <p>Callbacks may be invoked on threads other than the one calling {@link #drain()}. The
 * {@code invocations} counter is therefore only updated and read while holding this object's
 * monitor; a bare {@code int} increment could otherwise lose updates or be read stale.
 */
@SuppressWarnings("rawtypes")
static class TaskSchedulerContextDrainable implements TaskSchedulerContext {
  // Callbacks that drain() has seen complete; within this class only drain() touches it.
  int completedEvents;
  // Counted callbacks forwarded to the real context. Guarded by "this": update it through
  // recordInvocation() and read it through getInvocations().
  int invocations;
  private final TaskSchedulerContext real;
  private final CountingExecutorService countingExecutorService;
  // Number of taskAllocated() callbacks received.
  final AtomicInteger count = new AtomicInteger(0);

  /**
   * @param real wrapped context; its executor must be a {@link CountingExecutorService}, otherwise
   *     the cast below throws {@link ClassCastException}
   */
  public TaskSchedulerContextDrainable(TaskSchedulerContextImplWrapper real) {
    countingExecutorService = (CountingExecutorService) real.getExecutorService();
    this.real = real;
  }

  /** Counts one callback that will complete on the counting executor. */
  private synchronized void recordInvocation() {
    invocations++;
  }

  /** Returns the number of callbacks counted so far. */
  private synchronized int getInvocations() {
    return invocations;
  }

  @Override
  public void taskAllocated(Object task, Object appCookie, Container container) {
    count.incrementAndGet();
    recordInvocation();
    real.taskAllocated(task, appCookie, container);
  }

  @Override
  public void containerCompleted(Object taskLastAllocated,
      ContainerStatus containerStatus) {
    recordInvocation();
    real.containerCompleted(taskLastAllocated, containerStatus);
  }

  @Override
  public void containerBeingReleased(ContainerId containerId) {
    recordInvocation();
    real.containerBeingReleased(containerId);
  }

  @Override
  public void nodesUpdated(List<NodeReport> updatedNodes) {
    recordInvocation();
    real.nodesUpdated(updatedNodes);
  }

  @Override
  public void appShutdownRequested() {
    recordInvocation();
    real.appShutdownRequested();
  }

  @Override
  public void setApplicationRegistrationData(Resource maxContainerCapability,
      Map<ApplicationAccessType, String> appAcls, ByteBuffer key) {
    recordInvocation();
    real.setApplicationRegistrationData(maxContainerCapability, appAcls, key);
  }

  @Override
  public void reportError(@Nonnull ServicePluginError servicePluginError, String message,
      DagInfo dagInfo) {
    recordInvocation();
    real.reportError(servicePluginError, message, dagInfo);
  }

  @Override
  public float getProgress() {
    recordInvocation();
    return real.getProgress();
  }

  @Override
  public AppFinalStatus getFinalAppStatus() {
    recordInvocation();
    return real.getFinalAppStatus();
  }

  // Not counting invocations for methods which do not obtain locks
  // and do not go via the executor service.
  @Override
  public UserPayload getInitialUserPayload() {
    return real.getInitialUserPayload();
  }

  @Override
  public String getAppTrackingUrl() {
    return real.getAppTrackingUrl();
  }

  @Override
  public long getCustomClusterIdentifier() {
    return real.getCustomClusterIdentifier();
  }

  @Override
  public ContainerSignatureMatcher getContainerSignatureMatcher() {
    return real.getContainerSignatureMatcher();
  }

  @Override
  public ApplicationAttemptId getApplicationAttemptId() {
    return real.getApplicationAttemptId();
  }

  @Nullable
  @Override
  public DagInfo getCurrentDagInfo() {
    return real.getCurrentDagInfo();
  }

  @Override
  public String getAppHostName() {
    return real.getAppHostName();
  }

  @Override
  public int getAppClientPort() {
    return real.getAppClientPort();
  }

  @Override
  public boolean isSession() {
    return real.isSession();
  }

  @Override
  public AMState getAMState() {
    return real.getAMState();
  }

  @Override
  public void preemptContainer(ContainerId cId) {
    recordInvocation();
    real.preemptContainer(cId);
  }

  /**
   * Blocks until as many tasks have completed on the counting executor as callbacks have been
   * counted. Each wait is bounded to 5 seconds; on timeout the test fails.
   *
   * <p>The completed futures are not inspected, so exceptions thrown by the callbacks are not
   * surfaced here. {@code ExecutionException} stays declared for source compatibility with
   * existing callers.
   */
  public void drain() throws InterruptedException, ExecutionException {
    while (completedEvents < getInvocations()) {
      Future<?> f = countingExecutorService.completionService.poll(5000L, TimeUnit.MILLISECONDS);
      if (f != null) {
        completedEvents++;
      } else {
        fail("Timed out while trying to drain queue");
      }
    }
  }
}
/**
 * A {@link ContainerSignatureMatcher} under which every signature matches every other:
 * {@code isSuperSet} throws NullPointerException for null arguments and otherwise returns true,
 * {@code isExactMatch} always returns true, no additional resources are ever reported, and
 * {@code union} keeps the first signature.
 */
static class AlwaysMatchesContainerMatcher implements ContainerSignatureMatcher {

  @Override
  public boolean isSuperSet(Object cs1, Object cs2) {
    requireBoth(cs1, cs2);
    return true;
  }

  @Override
  public boolean isExactMatch(Object cs1, Object cs2) {
    // Unlike isSuperSet, no null check is performed here.
    return true;
  }

  @Override
  public Map<String, LocalResource> getAdditionalResources(Map<String, LocalResource> lr1,
      Map<String, LocalResource> lr2) {
    // A fresh, empty, mutable map on every call.
    Map<String, LocalResource> none = Maps.newHashMap();
    return none;
  }

  @Override
  public Object union(Object cs1, Object cs2) {
    return cs1;
  }

  /** Throws NullPointerException if either signature is null, checking the first one first. */
  private static void requireBoth(Object first, Object second) {
    Preconditions.checkNotNull(first, "Arguments cannot be null");
    Preconditions.checkNotNull(second, "Arguments cannot be null");
  }
}
/**
 * A {@link ContainerSignatureMatcher} that treats any non-null pair as a superset, but only
 * reports an exact match when both arguments are the same non-null instance (reference
 * equality). No additional resources are reported and {@code union} keeps the first signature.
 */
static class PreemptionMatcher implements ContainerSignatureMatcher {

  @Override
  public boolean isSuperSet(Object cs1, Object cs2) {
    requireBoth(cs1, cs2);
    return true;
  }

  @Override
  public boolean isExactMatch(Object cs1, Object cs2) {
    return cs1 != null && cs1 == cs2;
  }

  @Override
  public Map<String, LocalResource> getAdditionalResources(Map<String, LocalResource> lr1,
      Map<String, LocalResource> lr2) {
    // A fresh, empty, mutable map on every call.
    Map<String, LocalResource> none = Maps.newHashMap();
    return none;
  }

  @Override
  public Object union(Object cs1, Object cs2) {
    return cs1;
  }

  /** Throws NullPointerException if either signature is null, checking the first one first. */
  private static void requireBoth(Object first, Object second) {
    Preconditions.checkNotNull(first, "Arguments cannot be null");
    Preconditions.checkNotNull(second, "Arguments cannot be null");
  }
}
/**
 * Blocks until {@code drainNotifier} holds {@code true}. The check and the wait both happen under
 * the flag's monitor. For this to wake up, whoever sets the flag must call
 * {@code notify}/{@code notifyAll} on the same object.
 *
 * @throws InterruptedException if the waiting thread is interrupted
 */
static void waitForDelayedDrainNotify(AtomicBoolean drainNotifier)
    throws InterruptedException {
  synchronized (drainNotifier) {
    // Re-check after every wakeup, since wait() may return spuriously.
    for (;;) {
      if (drainNotifier.get()) {
        return;
      }
      drainNotifier.wait();
    }
  }
}
/**
 * Wraps {@code rawExecutor} in a {@link CountingExecutorService}, whose completion service lets
 * tests poll for finished callback tasks.
 */
static CountingExecutorService createCountingExecutingService(ExecutorService rawExecutor) {
  CountingExecutorService counting = new CountingExecutorService(rawExecutor);
  return counting;
}
/**
 * An {@link ExecutorService} that routes {@code submit(Callable)} and
 * {@code submit(Runnable, T)} through an {@link ExecutorCompletionService} over {@code real}, so
 * completed tasks can be polled from {@code completionService}. Lifecycle methods delegate to
 * {@code real}. Every other submission method throws {@link UnsupportedOperationException}.
 */
@SuppressWarnings({"rawtypes", "unchecked"})
private static class CountingExecutorService implements ExecutorService {

  final ExecutorService real;
  // Raw so that tasks with different result types can share one completion queue.
  final CompletionService completionService;

  CountingExecutorService(ExecutorService real) {
    this.real = real;
    this.completionService = new ExecutorCompletionService(this.real);
  }

  // --- Supported submissions: tracked by the completion service ---

  @Override
  public <T> Future<T> submit(Callable<T> task) {
    return completionService.submit(task);
  }

  @Override
  public <T> Future<T> submit(Runnable task, T result) {
    return completionService.submit(task, result);
  }

  // --- Lifecycle: delegated to the underlying executor ---

  @Override
  public void shutdown() {
    real.shutdown();
  }

  @Override
  public List<Runnable> shutdownNow() {
    return real.shutdownNow();
  }

  @Override
  public boolean isShutdown() {
    return real.isShutdown();
  }

  @Override
  public boolean isTerminated() {
    return real.isTerminated();
  }

  @Override
  public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
    return real.awaitTermination(timeout, unit);
  }

  // --- Unsupported: callbacks are expected to use the two submit overloads above ---

  @Override
  public void execute(Runnable command) {
    throw notExpected();
  }

  @Override
  public Future<?> submit(Runnable task) {
    throw notExpected();
  }

  @Override
  public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks)
      throws InterruptedException {
    throw notExpected();
  }

  @Override
  public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout,
      TimeUnit unit) throws InterruptedException {
    throw notExpected();
  }

  @Override
  public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException,
      ExecutionException {
    throw notExpected();
  }

  @Override
  public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
      throws InterruptedException, ExecutionException, TimeoutException {
    throw notExpected();
  }

  /** Builds the exception thrown by every unsupported operation. */
  private static UnsupportedOperationException notExpected() {
    return new UnsupportedOperationException("Not expected to be used");
  }
}
/**
 * Creates a mock non-session context with no attempt id, no custom cluster identifier and the
 * default always-matching container signature matcher.
 */
static TaskSchedulerContext setupMockTaskSchedulerContext(String appHost, int appPort,
    String appUrl, Configuration conf) {
  return setupMockTaskSchedulerContext(appHost, appPort, appUrl, false, null, null, null, conf);
}
/**
 * Creates a mock context with the given session flag, no attempt id, no custom cluster
 * identifier and the default always-matching container signature matcher.
 */
static TaskSchedulerContext setupMockTaskSchedulerContext(String appHost, int appPort,
    String appUrl, boolean isSession,
    Configuration conf) {
  ApplicationAttemptId noAttemptId = null;
  Long noCustomIdentifier = null;
  ContainerSignatureMatcher defaultMatcher = null;
  return setupMockTaskSchedulerContext(appHost, appPort, appUrl, isSession, noAttemptId,
      noCustomIdentifier, defaultMatcher, conf);
}
/**
 * Builds a Mockito mock of {@link TaskSchedulerContext}. It is stubbed with the given host, port,
 * tracking URL and session flag, {@code AMState.RUNNING_APP}, and an initial user payload built
 * from {@code conf}.
 *
 * @param appAttemptId stubbed as the attempt id only when non-null
 * @param customAppIdentifier stubbed as the custom cluster identifier only when non-null
 * @param containerSignatureMatcher matcher to return; an {@link AlwaysMatchesContainerMatcher}
 *     is returned instead when this is null
 * @throws TezUncheckedException wrapping the IOException if {@code conf} cannot be converted to
 *     a user payload
 */
static TaskSchedulerContext setupMockTaskSchedulerContext(String appHost, int appPort,
    String appUrl, boolean isSession,
    ApplicationAttemptId appAttemptId,
    Long customAppIdentifier,
    ContainerSignatureMatcher containerSignatureMatcher,
    Configuration conf) {
  TaskSchedulerContext ctx = mock(TaskSchedulerContext.class);
  when(ctx.getAppHostName()).thenReturn(appHost);
  when(ctx.getAppClientPort()).thenReturn(appPort);
  when(ctx.getAppTrackingUrl()).thenReturn(appUrl);
  when(ctx.getAMState()).thenReturn(TaskSchedulerContext.AMState.RUNNING_APP);

  // Build the payload before stubbing it, so a conversion failure never interrupts a stubbing.
  final UserPayload payload;
  try {
    payload = TezUtils.createUserPayloadFromConf(conf);
  } catch (IOException ioe) {
    throw new TezUncheckedException(ioe);
  }
  when(ctx.getInitialUserPayload()).thenReturn(payload);
  when(ctx.isSession()).thenReturn(isSession);

  ContainerSignatureMatcher matcher = (containerSignatureMatcher == null)
      ? new AlwaysMatchesContainerMatcher()
      : containerSignatureMatcher;
  when(ctx.getContainerSignatureMatcher()).thenReturn(matcher);

  // Optional values: left unstubbed (Mockito defaults) when not supplied.
  if (appAttemptId != null) {
    when(ctx.getApplicationAttemptId()).thenReturn(appAttemptId);
  }
  if (customAppIdentifier != null) {
    when(ctx.getCustomClusterIdentifier()).thenReturn(customAppIdentifier);
  }
  return ctx;
}
}