blob: 722aaa57d313a2b3fbdbde8cf43244af715f8bf4 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.dag.app.rm;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse;
import org.apache.hadoop.yarn.api.records.ApplicationAccessType;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.ContainerStatus;
import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
import org.apache.hadoop.yarn.api.records.NodeReport;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.client.api.AMRMClient;
import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
import org.apache.hadoop.yarn.client.api.async.impl.AMRMClientAsyncImpl;
import org.apache.hadoop.yarn.client.api.impl.AMRMClientImpl;
import org.apache.hadoop.yarn.event.Event;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.tez.dag.app.AppContext;
import org.apache.tez.dag.app.rm.TaskScheduler.CookieContainerRequest;
import org.apache.tez.dag.app.rm.TaskScheduler.TaskSchedulerAppCallback;
class TestTaskSchedulerHelpers {
// Mocking AMRMClientImpl to make use of getMatchingRequest
static class AMRMClientForTest extends AMRMClientImpl<CookieContainerRequest> {

  /**
   * Replaces the inherited start-up hook with a no-op, so starting this
   * client performs none of the work done by {@link AMRMClientImpl}.
   */
  @Override
  protected void serviceStart() {
    // Intentionally empty.
  }

  /**
   * Replaces the inherited shut-down hook with a no-op, mirroring
   * {@link #serviceStart()}.
   */
  @Override
  protected void serviceStop() {
    // Intentionally empty.
  }
}
// Mocking AMRMClientAsyncImpl to make use of getMatchingRequest
static class AMRMClientAsyncForTest
    extends AMRMClientAsyncImpl<CookieContainerRequest> {

  /**
   * Builds the async client around {@code client}, passing {@code null} as
   * the callback handler.
   *
   * @param client     the underlying client handed to the superclass
   * @param intervalMs interval, in milliseconds, handed to the superclass
   */
  public AMRMClientAsyncForTest(
      AMRMClient<CookieContainerRequest> client, int intervalMs) {
    // No CallbackHandler is supplied - callbacks are invoked independently by the test.
    super(client, intervalMs, null);
  }

  /**
   * Returns a Mockito-backed registration response instead of performing a
   * real registration. The response is stubbed to hand back a mock
   * {@link Resource} as the maximum capability and a mock ACL map; the three
   * arguments are ignored.
   *
   * @return the stubbed registration response
   */
  @SuppressWarnings("unchecked")
  @Override
  public RegisterApplicationMasterResponse registerApplicationMaster(
      String appHostName, int appHostPort, String appTrackingUrl) {
    Map<ApplicationAccessType, String> stubAcls = mock(Map.class);
    Resource stubMaxCapability = mock(Resource.class);
    RegisterApplicationMasterResponse stubResponse =
        mock(RegisterApplicationMasterResponse.class);

    when(stubResponse.getApplicationACLs()).thenReturn(stubAcls);
    when(stubResponse.getMaximumResourceCapability())
        .thenReturn(stubMaxCapability);
    return stubResponse;
  }

  /** Does nothing; all arguments are ignored. */
  @Override
  public void unregisterApplicationMaster(FinalApplicationStatus appStatus,
      String appMessage, String appTrackingUrl) {
    // Intentionally empty.
  }

  /** Replaces the inherited start-up hook with a no-op. */
  @Override
  protected void serviceStart() {
    // Intentionally empty.
  }

  /** Replaces the inherited shut-down hook with a no-op. */
  @Override
  protected void serviceStop() {
    // Intentionally empty.
  }
}
// Overrides start / stop. Will be controlled without the extra event handling thread.
static class TaskSchedulerEventHandlerForTest extends TaskSchedulerEventHandler {

  /** Client handed to every scheduler built by {@link #createTaskScheduler}. */
  private final AMRMClientAsync<CookieContainerRequest> amrmClientAsync;

  /**
   * @param appContext      forwarded to the superclass
   * @param eventHandler    forwarded to the superclass
   * @param amrmClientAsync client injected into the scheduler under test
   */
  @SuppressWarnings("rawtypes")
  public TaskSchedulerEventHandlerForTest(AppContext appContext,
      EventHandler eventHandler,
      AMRMClientAsync<CookieContainerRequest> amrmClientAsync) {
    // The second superclass argument is always null here.
    super(appContext, null, eventHandler);
    this.amrmClientAsync = amrmClientAsync;
  }

  /**
   * Creates a {@link TaskSchedulerWithDrainableAppCallback} that uses this
   * handler as its app callback and the injected client.
   */
  @Override
  public TaskScheduler createTaskScheduler(String host, int port, String trackingUrl) {
    return new TaskSchedulerWithDrainableAppCallback(
        this, host, port, trackingUrl, amrmClientAsync);
  }

  /**
   * Returns the current scheduler; after {@link #serviceStart()} has run this
   * is the Mockito spy created there.
   */
  public TaskScheduler getSpyTaskScheduler() {
    return taskScheduler;
  }

  /**
   * Builds a scheduler with fixed placeholder registration values,
   * initialises and starts it directly, then stores a spy of it as the
   * active scheduler.
   */
  @Override
  public void serviceStart() {
    TaskScheduler realScheduler = createTaskScheduler("host", 0, "");
    // Init the service so that reuse configuration is picked up.
    realScheduler.serviceInit(getConfig());
    realScheduler.serviceStart();
    this.taskScheduler = spy(realScheduler);
  }

  /** Does nothing on stop. */
  @Override
  public void serviceStop() {
    // Intentionally empty.
  }
}
@SuppressWarnings("rawtypes")
static class CapturingEventHandler implements EventHandler {
private List<Event> events = new LinkedList<Event>();
public void handle(Event event) {
events.add(event);
}
public void reset() {
events.clear();
}
public void verifyNoInvocations(Class<? extends Event> eventClass) {
for (Event e : events) {
assertFalse(e.getClass().getName().equals(eventClass.getName()));
}
}
public void verifyInvocation(Class<? extends Event> eventClass) {
for (Event e : events) {
if (e.getClass().getName().equals(eventClass.getName())) {
return;
}
}
fail("Expected Event: " + eventClass.getName() + " not sent");
}
}
static class TaskSchedulerWithDrainableAppCallback extends TaskScheduler {

  // Set by createAppCallbackDelegate. Deliberately left without an initializer.
  // NOTE(review): presumably the superclass constructor invokes
  // createAppCallbackDelegate, in which case an initializer here would wipe
  // the value afterwards - confirm before changing.
  private TaskSchedulerAppCallbackDrainable drainableAppCallback;

  /** Forwards all arguments unchanged to the four-argument superclass constructor. */
  public TaskSchedulerWithDrainableAppCallback(
      TaskSchedulerAppCallback appClient,
      String appHostName,
      int appHostPort,
      String appTrackingUrl) {
    super(appClient, appHostName, appHostPort, appTrackingUrl);
  }

  /**
   * Forwards all arguments unchanged to the superclass constructor that also
   * accepts an explicit client.
   */
  public TaskSchedulerWithDrainableAppCallback(
      TaskSchedulerAppCallback appClient,
      String appHostName,
      int appHostPort,
      String appTrackingUrl,
      AMRMClientAsync<CookieContainerRequest> client) {
    super(appClient, appHostName, appHostPort, appTrackingUrl, client);
  }

  /**
   * Wraps {@code realAppClient} in a {@link TaskSchedulerAppCallbackWrapper}
   * bound to {@code appCallbackExecutor}, then in a drainable decorator that
   * is remembered for {@link #getDrainableAppCallback()}.
   *
   * @param realAppClient the callback to decorate
   * @return the drainable decorator
   */
  @Override
  TaskSchedulerAppCallback createAppCallbackDelegate(
      TaskSchedulerAppCallback realAppClient) {
    TaskSchedulerAppCallbackWrapper wrapper =
        new TaskSchedulerAppCallbackWrapper(realAppClient, appCallbackExecutor);
    drainableAppCallback = new TaskSchedulerAppCallbackDrainable(wrapper);
    return drainableAppCallback;
  }

  /** Returns the decorator most recently built by {@code createAppCallbackDelegate}. */
  public TaskSchedulerAppCallbackDrainable getDrainableAppCallback() {
    return drainableAppCallback;
  }
}
@SuppressWarnings("rawtypes")
static class TaskSchedulerAppCallbackDrainable implements TaskSchedulerAppCallback {

  /** Longest time {@link #drain()} waits for a single completed future. */
  private static final long DRAIN_POLL_TIMEOUT_MS = 5000L;

  /** Number of futures {@link #drain()} has taken off the completion service. */
  int completedEvents;
  /** Number of callback methods invoked through this decorator. */
  int invocations;

  private final TaskSchedulerAppCallback real;
  private final CompletionService completionService;

  /**
   * Decorates {@code real}, sharing its completion service so that
   * {@link #drain()} can observe finished callbacks.
   *
   * @param real the wrapper every call is delegated to
   */
  public TaskSchedulerAppCallbackDrainable(TaskSchedulerAppCallbackWrapper real) {
    this.real = real;
    this.completionService = real.completionService;
  }

  /** Counts one more call made through this decorator. */
  private void recordInvocation() {
    invocations++;
  }

  /** Counts the call, then delegates. */
  @Override
  public void taskAllocated(Object task, Object appCookie, Container container) {
    recordInvocation();
    real.taskAllocated(task, appCookie, container);
  }

  /** Counts the call, then delegates. */
  @Override
  public void containerCompleted(Object taskLastAllocated,
      ContainerStatus containerStatus) {
    recordInvocation();
    real.containerCompleted(taskLastAllocated, containerStatus);
  }

  /** Counts the call, then delegates. */
  @Override
  public void containerBeingReleased(ContainerId containerId) {
    recordInvocation();
    real.containerBeingReleased(containerId);
  }

  /** Counts the call, then delegates. */
  @Override
  public void nodesUpdated(List<NodeReport> updatedNodes) {
    recordInvocation();
    real.nodesUpdated(updatedNodes);
  }

  /** Counts the call, then delegates. */
  @Override
  public void appShutdownRequested() {
    recordInvocation();
    real.appShutdownRequested();
  }

  /** Counts the call, then delegates. */
  @Override
  public void setApplicationRegistrationData(Resource maxContainerCapability,
      Map<ApplicationAccessType, String> appAcls) {
    recordInvocation();
    real.setApplicationRegistrationData(maxContainerCapability, appAcls);
  }

  /** Counts the call, then delegates. */
  @Override
  public void onError(Throwable t) {
    recordInvocation();
    real.onError(t);
  }

  /** Counts the call, then returns the delegate's answer. */
  @Override
  public float getProgress() {
    recordInvocation();
    return real.getProgress();
  }

  /** Counts the call, then returns the delegate's answer. */
  @Override
  public AppFinalStatus getFinalAppStatus() {
    recordInvocation();
    return real.getFinalAppStatus();
  }

  /**
   * Blocks until one completed future has been polled for every counted
   * invocation, failing the test if any single poll times out. The polled
   * futures are only counted; their results are not inspected.
   *
   * @throws InterruptedException if interrupted while polling
   * @throws ExecutionException   declared for callers; nothing in this method
   *                              body raises it
   */
  public void drain() throws InterruptedException, ExecutionException {
    while (completedEvents < invocations) {
      Future finished =
          completionService.poll(DRAIN_POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS);
      if (finished == null) {
        // fail() throws, so the counter below is not touched on timeout.
        fail("Timed out while trying to drain queue");
      }
      completedEvents++;
    }
  }
}
}