| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.tez.dag.app.rm; |
| |
| import static org.junit.Assert.assertFalse; |
| import static org.junit.Assert.fail; |
| import static org.mockito.Mockito.mock; |
| import static org.mockito.Mockito.spy; |
| import static org.mockito.Mockito.when; |
| |
| import javax.annotation.Nonnull; |
| import javax.annotation.Nullable; |
| import java.io.IOException; |
| import java.nio.ByteBuffer; |
| import java.util.Collection; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Queue; |
| import java.util.concurrent.Callable; |
| import java.util.concurrent.CompletionService; |
| import java.util.concurrent.ConcurrentLinkedQueue; |
| import java.util.concurrent.ExecutionException; |
| import java.util.concurrent.ExecutorCompletionService; |
| import java.util.concurrent.ExecutorService; |
| import java.util.concurrent.Future; |
| import java.util.concurrent.TimeUnit; |
| import java.util.concurrent.TimeoutException; |
| import java.util.concurrent.atomic.AtomicBoolean; |
| import java.util.concurrent.atomic.AtomicInteger; |
| |
| import com.google.common.base.Preconditions; |
| import com.google.common.collect.Lists; |
| import com.google.common.collect.Maps; |
| import org.apache.hadoop.conf.Configuration; |
| import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse; |
| import org.apache.hadoop.yarn.api.records.ApplicationAccessType; |
| import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; |
| import org.apache.hadoop.yarn.api.records.Container; |
| import org.apache.hadoop.yarn.api.records.ContainerId; |
| import org.apache.hadoop.yarn.api.records.ContainerStatus; |
| import org.apache.hadoop.yarn.api.records.FinalApplicationStatus; |
| import org.apache.hadoop.yarn.api.records.LocalResource; |
| import org.apache.hadoop.yarn.api.records.NodeReport; |
| import org.apache.hadoop.yarn.api.records.Resource; |
| import org.apache.hadoop.yarn.client.api.AMRMClient; |
| import org.apache.hadoop.yarn.client.api.impl.AMRMClientImpl; |
| import org.apache.hadoop.yarn.event.Event; |
| import org.apache.hadoop.yarn.event.EventHandler; |
| import org.apache.tez.common.ContainerSignatureMatcher; |
| import org.apache.tez.common.TezUtils; |
| import org.apache.tez.dag.api.NamedEntityDescriptor; |
| import org.apache.tez.dag.api.TezUncheckedException; |
| import org.apache.tez.dag.api.UserPayload; |
| import org.apache.tez.dag.app.AppContext; |
| import org.apache.tez.dag.app.ServicePluginLifecycleAbstractService; |
| import org.apache.tez.dag.app.rm.YarnTaskSchedulerService.CookieContainerRequest; |
| import org.apache.tez.hadoop.shim.HadoopShimsLoader; |
| import org.apache.tez.serviceplugins.api.DagInfo; |
| import org.apache.tez.serviceplugins.api.ServicePluginError; |
| import org.apache.tez.serviceplugins.api.TaskScheduler; |
| import org.apache.tez.serviceplugins.api.TaskSchedulerContext; |
| |
| class TestTaskSchedulerHelpers { |
| |
| // Mocking AMRMClientImpl to make use of getMatchingRequest |
| static class AMRMClientForTest extends AMRMClientImpl<CookieContainerRequest> { |
| |
| @Override |
| protected void serviceStart() { |
| } |
| |
| @Override |
| protected void serviceStop() { |
| } |
| } |
| |
| |
| // Mocking AMRMClientAsyncImpl to make use of getMatchingRequest |
| static class AMRMClientAsyncForTest extends |
| TezAMRMClientAsync<CookieContainerRequest> { |
| |
| public AMRMClientAsyncForTest( |
| AMRMClient<CookieContainerRequest> client, |
| int intervalMs) { |
| // CallbackHandler is not needed - will be called independently in the test. |
| super(client, intervalMs, null); |
| } |
| |
| @SuppressWarnings("unchecked") |
| @Override |
| public RegisterApplicationMasterResponse registerApplicationMaster( |
| String appHostName, int appHostPort, String appTrackingUrl) { |
| RegisterApplicationMasterResponse mockRegResponse = mock(RegisterApplicationMasterResponse.class); |
| Resource mockMaxResource = mock(Resource.class); |
| Map<ApplicationAccessType, String> mockAcls = mock(Map.class); |
| when(mockRegResponse.getMaximumResourceCapability()).thenReturn( |
| mockMaxResource); |
| when(mockRegResponse.getApplicationACLs()).thenReturn(mockAcls); |
| return mockRegResponse; |
| } |
| |
| @Override |
| public void unregisterApplicationMaster(FinalApplicationStatus appStatus, |
| String appMessage, String appTrackingUrl) { |
| } |
| |
| @Override |
| protected void serviceStart() { |
| } |
| |
| @Override |
| protected void serviceStop() { |
| } |
| } |
| |
| // Overrides start / stop. Will be controlled without the extra event handling thread. |
| static class TaskSchedulerManagerForTest extends |
| TaskSchedulerManager { |
| |
| private TezAMRMClientAsync<CookieContainerRequest> amrmClientAsync; |
| private ContainerSignatureMatcher containerSignatureMatcher; |
| private UserPayload defaultPayload; |
| |
| @SuppressWarnings("rawtypes") |
| public TaskSchedulerManagerForTest(AppContext appContext, |
| EventHandler eventHandler, |
| TezAMRMClientAsync<CookieContainerRequest> amrmClientAsync, |
| ContainerSignatureMatcher containerSignatureMatcher, |
| UserPayload defaultPayload) { |
| super(appContext, null, eventHandler, containerSignatureMatcher, null, |
| Lists.newArrayList(new NamedEntityDescriptor("FakeScheduler", null)), |
| false, new HadoopShimsLoader(appContext.getAMConf()).getHadoopShim()); |
| this.amrmClientAsync = amrmClientAsync; |
| this.containerSignatureMatcher = containerSignatureMatcher; |
| this.defaultPayload = defaultPayload; |
| } |
| |
| @SuppressWarnings("unchecked") |
| @Override |
| public void instantiateSchedulers(String host, int port, String trackingUrl, |
| AppContext appContext) { |
| TaskSchedulerContext taskSchedulerContext = |
| new TaskSchedulerContextImpl(this, appContext, 0, trackingUrl, 1000, host, port, |
| defaultPayload); |
| TaskSchedulerContextImplWrapper wrapper = |
| new TaskSchedulerContextImplWrapper(taskSchedulerContext, |
| new CountingExecutorService(appCallbackExecutor)); |
| TaskSchedulerContextDrainable drainable = new TaskSchedulerContextDrainable(wrapper); |
| |
| taskSchedulers[0] = new TaskSchedulerWrapper( |
| spy(new TaskSchedulerWithDrainableContext(drainable, amrmClientAsync))); |
| taskSchedulerServiceWrappers[0] = |
| new ServicePluginLifecycleAbstractService(taskSchedulers[0].getTaskScheduler()); |
| } |
| |
| public TaskScheduler getSpyTaskScheduler() { |
| return taskSchedulers[0].getTaskScheduler(); |
| } |
| |
| @Override |
| public void serviceStart() { |
| instantiateSchedulers("host", 0, "", appContext); |
| // Init the service so that reuse configuration is picked up. |
| taskSchedulerServiceWrappers[0].init(getConfig()); |
| taskSchedulerServiceWrappers[0].start(); |
| } |
| |
| @Override |
| public void serviceStop() { |
| } |
| } |
| |
| @SuppressWarnings("rawtypes") |
| static class CapturingEventHandler implements EventHandler { |
| |
| private Queue<Event> events = new ConcurrentLinkedQueue<Event>(); |
| |
| public void handle(Event event) { |
| events.add(event); |
| } |
| |
| public void reset() { |
| events.clear(); |
| } |
| |
| public void verifyNoInvocations(Class<? extends Event> eventClass) { |
| for (Event e : events) { |
| assertFalse(e.getClass().getName().equals(eventClass.getName())); |
| } |
| } |
| |
| public Event verifyInvocation(Class<? extends Event> eventClass) { |
| for (Event e : events) { |
| if (e.getClass().getName().equals(eventClass.getName())) { |
| return e; |
| } |
| } |
| fail("Expected Event: " + eventClass.getName() + " not sent"); |
| return null; |
| } |
| } |
| |
| static class TaskSchedulerWithDrainableContext extends YarnTaskSchedulerService { |
| |
| |
| public TaskSchedulerWithDrainableContext( |
| TaskSchedulerContextDrainable appClient, |
| TezAMRMClientAsync<CookieContainerRequest> client) { |
| super(appClient, client); |
| shouldUnregister.set(true); |
| } |
| |
| public TaskSchedulerContextDrainable getDrainableAppCallback() { |
| return (TaskSchedulerContextDrainable)getContext(); |
| } |
| } |
| |
| @SuppressWarnings("rawtypes") |
| static class TaskSchedulerContextDrainable implements TaskSchedulerContext { |
| int completedEvents; |
| int invocations; |
| private TaskSchedulerContext real; |
| private CountingExecutorService countingExecutorService; |
| final AtomicInteger count = new AtomicInteger(0); |
| |
| public TaskSchedulerContextDrainable(TaskSchedulerContextImplWrapper real) { |
| countingExecutorService = (CountingExecutorService) real.getExecutorService(); |
| this.real = real; |
| } |
| |
| @Override |
| public void taskAllocated(Object task, Object appCookie, Container container) { |
| count.incrementAndGet(); |
| invocations++; |
| real.taskAllocated(task, appCookie, container); |
| } |
| |
| @Override |
| public void containerCompleted(Object taskLastAllocated, |
| ContainerStatus containerStatus) { |
| invocations++; |
| real.containerCompleted(taskLastAllocated, containerStatus); |
| } |
| |
| @Override |
| public void containerBeingReleased(ContainerId containerId) { |
| invocations++; |
| real.containerBeingReleased(containerId); |
| } |
| |
| @Override |
| public void nodesUpdated(List<NodeReport> updatedNodes) { |
| invocations++; |
| real.nodesUpdated(updatedNodes); |
| } |
| |
| @Override |
| public void appShutdownRequested() { |
| invocations++; |
| real.appShutdownRequested(); |
| } |
| |
| @Override |
| public void setApplicationRegistrationData(Resource maxContainerCapability, |
| Map<ApplicationAccessType, String> appAcls, ByteBuffer key) { |
| invocations++; |
| real.setApplicationRegistrationData(maxContainerCapability, appAcls, key); |
| } |
| |
| @Override |
| public void reportError(@Nonnull ServicePluginError servicePluginError, String message, |
| DagInfo dagInfo) { |
| invocations++; |
| real.reportError(servicePluginError, message, dagInfo); |
| } |
| |
| @Override |
| public float getProgress() { |
| invocations++; |
| return real.getProgress(); |
| } |
| |
| @Override |
| public AppFinalStatus getFinalAppStatus() { |
| invocations++; |
| return real.getFinalAppStatus(); |
| } |
| |
| // Not incrementing invocations for methods which to not obtain locks, |
| // and do not go via the executor service. |
| @Override |
| public UserPayload getInitialUserPayload() { |
| return real.getInitialUserPayload(); |
| } |
| |
| @Override |
| public String getAppTrackingUrl() { |
| return real.getAppTrackingUrl(); |
| } |
| |
| @Override |
| public long getCustomClusterIdentifier() { |
| return real.getCustomClusterIdentifier(); |
| } |
| |
| @Override |
| public ContainerSignatureMatcher getContainerSignatureMatcher() { |
| return real.getContainerSignatureMatcher(); |
| } |
| |
| @Override |
| public ApplicationAttemptId getApplicationAttemptId() { |
| return real.getApplicationAttemptId(); |
| } |
| |
| @Nullable |
| @Override |
| public DagInfo getCurrentDagInfo() { |
| return real.getCurrentDagInfo(); |
| } |
| |
| @Override |
| public String getAppHostName() { |
| return real.getAppHostName(); |
| } |
| |
| @Override |
| public int getAppClientPort() { |
| return real.getAppClientPort(); |
| } |
| |
| @Override |
| public boolean isSession() { |
| return real.isSession(); |
| } |
| |
| @Override |
| public AMState getAMState() { |
| return real.getAMState(); |
| } |
| |
| @Override |
| public void preemptContainer(ContainerId cId) { |
| invocations++; |
| real.preemptContainer(cId); |
| } |
| |
| public void drain() throws InterruptedException, ExecutionException { |
| while (completedEvents < invocations) { |
| Future f = countingExecutorService.completionService.poll(5000l, TimeUnit.MILLISECONDS); |
| if (f != null) { |
| completedEvents++; |
| } else { |
| fail("Timed out while trying to drain queue"); |
| } |
| } |
| } |
| } |
| |
| static class AlwaysMatchesContainerMatcher implements ContainerSignatureMatcher { |
| |
| @Override |
| public boolean isSuperSet(Object cs1, Object cs2) { |
| Preconditions.checkNotNull(cs1, "Arguments cannot be null"); |
| Preconditions.checkNotNull(cs2, "Arguments cannot be null"); |
| return true; |
| } |
| |
| @Override |
| public boolean isExactMatch(Object cs1, Object cs2) { |
| return true; |
| } |
| |
| @Override |
| public Map<String, LocalResource> getAdditionalResources(Map<String, LocalResource> lr1, |
| Map<String, LocalResource> lr2) { |
| return Maps.newHashMap(); |
| } |
| |
| @Override |
| public Object union(Object cs1, Object cs2) { |
| return cs1; |
| } |
| } |
| |
| static class PreemptionMatcher implements ContainerSignatureMatcher { |
| @Override |
| public boolean isSuperSet(Object cs1, Object cs2) { |
| Preconditions.checkNotNull(cs1, "Arguments cannot be null"); |
| Preconditions.checkNotNull(cs2, "Arguments cannot be null"); |
| return true; |
| } |
| |
| @Override |
| public boolean isExactMatch(Object cs1, Object cs2) { |
| if (cs1 == cs2 && cs1 != null) { |
| return true; |
| } |
| return false; |
| } |
| |
| @Override |
| public Map<String, LocalResource> getAdditionalResources(Map<String, LocalResource> lr1, |
| Map<String, LocalResource> lr2) { |
| return Maps.newHashMap(); |
| } |
| |
| @Override |
| public Object union(Object cs1, Object cs2) { |
| return cs1; |
| } |
| } |
| |
| |
| static void waitForDelayedDrainNotify(AtomicBoolean drainNotifier) |
| throws InterruptedException { |
| synchronized (drainNotifier) { |
| while (!drainNotifier.get()) { |
| drainNotifier.wait(); |
| } |
| } |
| } |
| |
| static CountingExecutorService createCountingExecutingService(ExecutorService rawExecutor) { |
| return new CountingExecutorService(rawExecutor); |
| } |
| |
| @SuppressWarnings({"rawtypes", "unchecked"}) |
| private static class CountingExecutorService implements ExecutorService { |
| |
| final ExecutorService real; |
| final CompletionService completionService; |
| |
| CountingExecutorService(ExecutorService real) { |
| this.real = real; |
| completionService = new ExecutorCompletionService(real); |
| } |
| |
| @Override |
| public void execute(Runnable command) { |
| throw new UnsupportedOperationException("Not expected to be used"); |
| } |
| |
| @Override |
| public void shutdown() { |
| real.shutdown(); |
| } |
| |
| @Override |
| public List<Runnable> shutdownNow() { |
| return real.shutdownNow(); |
| } |
| |
| @Override |
| public boolean isShutdown() { |
| return real.isShutdown(); |
| } |
| |
| @Override |
| public boolean isTerminated() { |
| return real.isTerminated(); |
| } |
| |
| @Override |
| public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { |
| return real.awaitTermination(timeout, unit); |
| } |
| |
| @Override |
| public <T> Future<T> submit(Callable<T> task) { |
| return completionService.submit(task); |
| } |
| |
| @Override |
| public <T> Future<T> submit(Runnable task, T result) { |
| return completionService.submit(task, result); |
| } |
| |
| @Override |
| public Future<?> submit(Runnable task) { |
| throw new UnsupportedOperationException("Not expected to be used"); |
| } |
| |
| @Override |
| public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) |
| throws InterruptedException { |
| throw new UnsupportedOperationException("Not expected to be used"); |
| } |
| |
| @Override |
| public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout, |
| TimeUnit unit) throws InterruptedException { |
| throw new UnsupportedOperationException("Not expected to be used"); |
| } |
| |
| @Override |
| public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException, |
| ExecutionException { |
| throw new UnsupportedOperationException("Not expected to be used"); |
| } |
| |
| @Override |
| public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit) |
| throws InterruptedException, ExecutionException, TimeoutException { |
| throw new UnsupportedOperationException("Not expected to be used"); |
| } |
| } |
| |
| static TaskSchedulerContext setupMockTaskSchedulerContext(String appHost, int appPort, |
| String appUrl, Configuration conf) { |
| return setupMockTaskSchedulerContext(appHost, appPort, appUrl, false, conf); |
| } |
| |
| static TaskSchedulerContext setupMockTaskSchedulerContext(String appHost, int appPort, |
| String appUrl, boolean isSession, |
| Configuration conf) { |
| return setupMockTaskSchedulerContext(appHost, appPort, appUrl, isSession, null, null, null, |
| conf); |
| } |
| |
| static TaskSchedulerContext setupMockTaskSchedulerContext(String appHost, int appPort, |
| String appUrl, boolean isSession, |
| ApplicationAttemptId appAttemptId, |
| Long customAppIdentifier, |
| ContainerSignatureMatcher containerSignatureMatcher, |
| Configuration conf) { |
| |
| TaskSchedulerContext mockContext = mock(TaskSchedulerContext.class); |
| when(mockContext.getAppHostName()).thenReturn(appHost); |
| when(mockContext.getAppClientPort()).thenReturn(appPort); |
| when(mockContext.getAppTrackingUrl()).thenReturn(appUrl); |
| |
| when(mockContext.getAMState()).thenReturn(TaskSchedulerContext.AMState.RUNNING_APP); |
| UserPayload userPayload; |
| try { |
| userPayload = TezUtils.createUserPayloadFromConf(conf); |
| } catch (IOException e) { |
| throw new TezUncheckedException(e); |
| } |
| when(mockContext.getInitialUserPayload()).thenReturn(userPayload); |
| when(mockContext.isSession()).thenReturn(isSession); |
| if (containerSignatureMatcher != null) { |
| when(mockContext.getContainerSignatureMatcher()) |
| .thenReturn(containerSignatureMatcher); |
| } else { |
| when(mockContext.getContainerSignatureMatcher()) |
| .thenReturn(new AlwaysMatchesContainerMatcher()); |
| } |
| if (appAttemptId != null) { |
| when(mockContext.getApplicationAttemptId()).thenReturn(appAttemptId); |
| } |
| if (customAppIdentifier != null) { |
| when(mockContext.getCustomClusterIdentifier()).thenReturn(customAppIdentifier); |
| } |
| |
| return mockContext; |
| } |
| |
| } |