blob: c7f97d3ea9c09e612eafe04f43408aed16ff4078 [file] [log] [blame]
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.dag.app;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import javax.annotation.Nullable;
import java.io.IOException;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.NodeId;
import org.apache.hadoop.yarn.event.Event;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.tez.common.TezUtils;
import org.apache.tez.dag.api.NamedEntityDescriptor;
import org.apache.tez.dag.app.dag.event.DAGEventTerminateDag;
import org.apache.tez.dag.helpers.DagInfoImplForTest;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.serviceplugins.api.ServicePluginErrorDefaults;
import org.apache.tez.serviceplugins.api.ServicePluginException;
import org.apache.tez.serviceplugins.api.TaskCommunicator;
import org.apache.tez.serviceplugins.api.TaskCommunicatorContext;
import org.apache.tez.dag.api.TezConstants;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.dag.app.dag.DAG;
import org.apache.tez.dag.app.dag.event.DAGAppMasterEventType;
import org.apache.tez.dag.app.dag.event.DAGAppMasterEventUserServiceFatalError;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.api.impl.TaskSpec;
import org.apache.tez.serviceplugins.api.ContainerEndReason;
import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
import org.apache.tez.serviceplugins.api.TaskCommunicatorDescriptor;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
public class TestTaskCommunicatorManager {
/**
 * Clears the static bookkeeping kept by {@link TaskCommManagerForMultipleCommTest}.
 * Carries both {@code @Before} and {@code @After}, so it runs on either side of every test.
 */
@Before
@After
public void resetForNextTest() {
TaskCommManagerForMultipleCommTest.reset();
}
/**
 * Constructing the manager with a {@code null} descriptor list must be rejected with an
 * {@link IllegalArgumentException}.
 */
@Test(timeout = 5000)
public void testNoTaskCommSpecified() throws IOException, TezException {
  AppContext mockAppContext = mock(AppContext.class);
  TaskHeartbeatHandler taskHeartbeats = mock(TaskHeartbeatHandler.class);
  ContainerHeartbeatHandler containerHeartbeats = mock(ContainerHeartbeatHandler.class);

  try {
    new TaskCommManagerForMultipleCommTest(mockAppContext, taskHeartbeats, containerHeartbeats,
        null);
    fail("Initialization should have failed without a TaskComm specified");
  } catch (IllegalArgumentException expected) {
    // Expected outcome: the constructor refused the missing descriptor list.
  }
}
/**
 * Registers a single custom communicator ({@link FakeTaskComm}) and checks that it is the only
 * one created, that neither the YARN nor the uber communicator is created, and that the custom
 * payload is what the communicator context exposes.
 */
@Test(timeout = 5000)
public void testCustomTaskCommSpecified() throws IOException, TezException {
  final String pluginName = "customTaskComm";

  // Four-byte buffer holding the int 3 at offset 0; used as the custom plugin payload.
  ByteBuffer payloadBytes = ByteBuffer.allocate(4).putInt(0, 3);
  UserPayload pluginPayload = UserPayload.create(payloadBytes);

  List<NamedEntityDescriptor> descriptors = new LinkedList<>();
  descriptors.add(
      new NamedEntityDescriptor(pluginName, FakeTaskComm.class.getName())
          .setUserPayload(pluginPayload));

  TaskCommManagerForMultipleCommTest manager = new TaskCommManagerForMultipleCommTest(
      mock(AppContext.class),
      mock(TaskHeartbeatHandler.class),
      mock(ContainerHeartbeatHandler.class),
      descriptors);

  try {
    manager.init(new Configuration(false));
    manager.start();

    assertEquals(1, TaskCommManagerForMultipleCommTest.getNumTaskComms());
    assertFalse(TaskCommManagerForMultipleCommTest.getYarnTaskCommCreated());
    assertFalse(TaskCommManagerForMultipleCommTest.getUberTaskCommCreated());

    assertEquals(pluginName, TaskCommManagerForMultipleCommTest.getTaskCommName(0));
    UserPayload seenPayload =
        TaskCommManagerForMultipleCommTest.getTaskCommContext(0).getInitialUserPayload();
    assertEquals(payloadBytes, seenPayload.getPayload());
  } finally {
    manager.stop();
  }
}
/**
 * Registers a custom communicator followed by the YARN service-plugin entry and checks that both
 * are created in that order, each seeing its own payload, and that no uber communicator is
 * created.
 */
@Test(timeout = 5000)
public void testMultipleTaskComms() throws IOException, TezException {
  final String pluginName = "customTaskComm";
  final String yarnPluginName = TezConstants.getTezYarnServicePluginName();

  // Payload for the YARN entry: a serialized Configuration carrying one marker key.
  Configuration markerConf = new Configuration(false);
  markerConf.set("testkey", "testvalue");
  UserPayload yarnPayload = TezUtils.createUserPayloadFromConf(markerConf);

  // Payload for the custom entry: four bytes holding the int 3 at offset 0.
  ByteBuffer customBytes = ByteBuffer.allocate(4).putInt(0, 3);
  UserPayload customPayload = UserPayload.create(customBytes);

  List<NamedEntityDescriptor> descriptors = new LinkedList<>();
  descriptors.add(
      new NamedEntityDescriptor(pluginName, FakeTaskComm.class.getName())
          .setUserPayload(customPayload));
  descriptors.add(
      new NamedEntityDescriptor(yarnPluginName, null).setUserPayload(yarnPayload));

  TaskCommManagerForMultipleCommTest manager = new TaskCommManagerForMultipleCommTest(
      mock(AppContext.class),
      mock(TaskHeartbeatHandler.class),
      mock(ContainerHeartbeatHandler.class),
      descriptors);

  try {
    manager.init(new Configuration(false));
    manager.start();

    assertEquals(2, TaskCommManagerForMultipleCommTest.getNumTaskComms());
    assertTrue(TaskCommManagerForMultipleCommTest.getYarnTaskCommCreated());
    assertFalse(TaskCommManagerForMultipleCommTest.getUberTaskCommCreated());

    // Slot 0: the custom communicator with its raw byte payload.
    assertEquals(pluginName, TaskCommManagerForMultipleCommTest.getTaskCommName(0));
    assertEquals(customBytes,
        TaskCommManagerForMultipleCommTest.getTaskCommContext(0).getInitialUserPayload()
            .getPayload());

    // Slot 1: the YARN communicator; its payload must decode back to the marker key.
    assertEquals(yarnPluginName, TaskCommManagerForMultipleCommTest.getTaskCommName(1));
    Configuration decoded = TezUtils.createConfFromUserPayload(
        TaskCommManagerForMultipleCommTest.getTaskCommContext(1).getInitialUserPayload());
    assertEquals("testvalue", decoded.get("testkey"));
  } finally {
    manager.stop();
  }
}
/**
 * Checks that {@code registerRunningContainer(containerId, index)} is forwarded to the
 * communicator at that index, with the host and port taken from the stubbed {@link NodeId}, and
 * that stopping the manager shuts both communicators down.
 *
 * <p>Communicator 0 is the spied {@link FakeTaskComm}; communicator 1 is the shared YARN mock
 * handed out by {@link TaskCommManagerForMultipleCommTest}.
 */
@Test(timeout = 5000)
public void testEventRouting() throws Exception {
  AppContext appContext = mock(AppContext.class, RETURNS_DEEP_STUBS);
  NodeId nodeId = NodeId.newInstance("host1", 3131);
  // Any container lookup resolves to host1:3131.
  when(appContext.getAllContainers().get(any(ContainerId.class)).getContainer().getNodeId())
      .thenReturn(nodeId);
  TaskHeartbeatHandler thh = mock(TaskHeartbeatHandler.class);
  ContainerHeartbeatHandler chh = mock(ContainerHeartbeatHandler.class);

  Configuration conf = new Configuration(false);
  conf.set("testkey", "testvalue");
  UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);

  String customTaskCommName = "customTaskComm";
  List<NamedEntityDescriptor> taskCommDescriptors = new LinkedList<>();
  ByteBuffer bb = ByteBuffer.allocate(4);
  bb.putInt(0, 3);
  UserPayload customPayload = UserPayload.create(bb);
  taskCommDescriptors.add(
      new NamedEntityDescriptor(customTaskCommName, FakeTaskComm.class.getName())
          .setUserPayload(customPayload));
  taskCommDescriptors.add(
      new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
          .setUserPayload(defaultPayload));

  TaskCommManagerForMultipleCommTest tcm =
      new TaskCommManagerForMultipleCommTest(appContext, thh, chh, taskCommDescriptors);

  try {
    tcm.init(new Configuration(false));
    tcm.start();

    assertEquals(2, tcm.getNumTaskComms());
    assertTrue(tcm.getYarnTaskCommCreated());
    assertFalse(tcm.getUberTaskCommCreated());

    verify(tcm.getTestTaskComm(0)).initialize();
    verify(tcm.getTestTaskComm(0)).start();
    verify(tcm.getTestTaskComm(1)).initialize();
    verify(tcm.getTestTaskComm(1)).start();

    ContainerId containerId1 = mock(ContainerId.class);
    tcm.registerRunningContainer(containerId1, 0);
    verify(tcm.getTestTaskComm(0)).registerRunningContainer(eq(containerId1), eq("host1"),
        eq(3131));

    ContainerId containerId2 = mock(ContainerId.class);
    tcm.registerRunningContainer(containerId2, 1);
    verify(tcm.getTestTaskComm(1)).registerRunningContainer(eq(containerId2), eq("host1"),
        eq(3131));
  } finally {
    // Cleanup only. Assertions are kept out of finally so that a failure raised by the try
    // body is never replaced by a secondary verification failure.
    tcm.stop();
  }

  // Reached only when everything above passed and stop() returned normally.
  verify(tcm.getTaskCommunicator(0).getTaskCommunicator()).shutdown();
  verify(tcm.getTaskCommunicator(1).getTaskCommunicator()).shutdown();
}
/**
 * Drives {@link TaskCommForFailureTest}, which calls {@code reportError} on its context from
 * {@code registerRunningContainer} and {@code dagComplete}, and checks which event the manager
 * hands to the application's event handler in each case:
 * a {@link DAGEventTerminateDag} for the first and a
 * {@link DAGAppMasterEventUserServiceFatalError} for the second.
 */
@SuppressWarnings("unchecked")
@Test(timeout = 5000)
public void testReportFailureFromTaskCommunicator() throws TezException {
  final String expectedSource = "[0:testTaskCommunicator]";

  EventHandler eventHandler = mock(EventHandler.class);

  // Current DAG as seen through the app context.
  TezDAGID dagId = TezDAGID.getInstance(ApplicationId.newInstance(1, 0), DAG_INDEX);
  DAG dag = mock(DAG.class);
  doReturn(DAG_NAME).when(dag).getName();
  doReturn(dagId).when(dag).getID();

  AppContext appContext = mock(AppContext.class, RETURNS_DEEP_STUBS);
  doReturn("testTaskCommunicator").when(appContext).getTaskCommunicatorName(0);
  doReturn(eventHandler).when(appContext).getEventHandler();
  doReturn(dag).when(appContext).getCurrentDAG();

  NamedEntityDescriptor<TaskCommunicatorDescriptor> failingCommDescriptor =
      new NamedEntityDescriptor<>("testTaskCommunicator", TaskCommForFailureTest.class.getName());
  List<NamedEntityDescriptor> descriptors = new LinkedList<>();
  descriptors.add(failingCommDescriptor);

  TaskCommunicatorManager manager = new TaskCommunicatorManager(
      appContext,
      mock(TaskHeartbeatHandler.class),
      mock(ContainerHeartbeatHandler.class),
      descriptors);

  try {
    manager.init(new Configuration());
    manager.start();

    // Step 1: error reported against a specific DAG while registering a container.
    manager.registerRunningContainer(mock(ContainerId.class), 0);

    ArgumentCaptor<Event> captor = ArgumentCaptor.forClass(Event.class);
    verify(eventHandler, times(1)).handle(captor.capture());
    Event captured = captor.getValue();
    assertTrue(captured instanceof DAGEventTerminateDag);
    String killDiagnostics = ((DAGEventTerminateDag) captured).getDiagnosticInfo();
    assertTrue(killDiagnostics.contains("ReportError"));
    assertTrue(killDiagnostics.contains(ServicePluginErrorDefaults.SERVICE_UNAVAILABLE.name()));
    assertTrue(killDiagnostics.contains(expectedSource));

    // Forget the first event so the next verification counts from zero.
    reset(eventHandler);

    // Step 2: error reported with no DAG info during DAG completion.
    manager.dagComplete(dag);

    captor = ArgumentCaptor.forClass(Event.class);
    verify(eventHandler, times(1)).handle(captor.capture());
    captured = captor.getValue();
    assertTrue(captured instanceof DAGAppMasterEventUserServiceFatalError);
    DAGAppMasterEventUserServiceFatalError fatal =
        (DAGAppMasterEventUserServiceFatalError) captured;
    assertEquals(DAGAppMasterEventType.TASK_COMMUNICATOR_SERVICE_FATAL_ERROR, fatal.getType());
    String fatalDiagnostics = fatal.getDiagnosticInfo();
    assertTrue(fatalDiagnostics.contains("ReportedFatalError"));
    assertTrue(fatalDiagnostics.contains(ServicePluginErrorDefaults.INCONSISTENT_STATE.name()));
    assertTrue(fatalDiagnostics.contains(expectedSource));
  } finally {
    manager.stop();
  }
}
/**
 * Uses a {@link TaskCommunicator} mock whose default answer ({@link ExceptionAnswer}) throws from
 * most methods, and checks that each failure surfaces as a
 * {@link DAGAppMasterEventUserServiceFatalError} carrying the thrown message, a description of
 * the failed operation and the communicator identifier.
 */
@SuppressWarnings("unchecked")
@Test(timeout = 5000)
public void testTaskCommunicatorUserError() {
  final String expectedId = "[0:testTaskCommunicator]";

  TaskCommunicatorContextImpl commContext = mock(TaskCommunicatorContextImpl.class);
  TaskCommunicator throwingComm = mock(TaskCommunicator.class, new ExceptionAnswer());
  doReturn(commContext).when(throwingComm).getContext();

  EventHandler eventHandler = mock(EventHandler.class);
  AppContext appContext = mock(AppContext.class, RETURNS_DEEP_STUBS);
  when(appContext.getEventHandler()).thenReturn(eventHandler);
  doReturn("testTaskCommunicator").when(appContext).getTaskCommunicatorName(0);

  TaskCommunicatorManager manager = new TaskCommunicatorManager(
      throwingComm,
      appContext,
      mock(TaskHeartbeatHandler.class),
      mock(ContainerHeartbeatHandler.class));

  try {
    manager.init(new Configuration(false));
    manager.start();

    // First failing call: DAG completion.
    DAG completedDag = mock(DAG.class, RETURNS_DEEP_STUBS);
    when(completedDag.getID().getId()).thenReturn(1);
    manager.dagComplete(completedDag);

    ArgumentCaptor<Event> captor = ArgumentCaptor.forClass(Event.class);
    verify(eventHandler, times(1)).handle(captor.capture());
    Event captured = captor.getValue();
    assertTrue(captured instanceof DAGAppMasterEventUserServiceFatalError);
    DAGAppMasterEventUserServiceFatalError fatal =
        (DAGAppMasterEventUserServiceFatalError) captured;
    assertEquals(DAGAppMasterEventType.TASK_COMMUNICATOR_SERVICE_FATAL_ERROR, fatal.getType());
    assertTrue(fatal.getError().getMessage().contains("TestException_" + "dagComplete"));
    assertTrue(fatal.getDiagnosticInfo().contains("DAG completion"));
    assertTrue(fatal.getDiagnosticInfo().contains(expectedId));

    // Second failing call: container registration. The handler has now seen two events in
    // total, so the second captured value is the one under test.
    when(appContext.getAllContainers().get(any(ContainerId.class)).getContainer().getNodeId())
        .thenReturn(mock(NodeId.class));
    manager.registerRunningContainer(mock(ContainerId.class), 0);

    captor = ArgumentCaptor.forClass(Event.class);
    verify(eventHandler, times(2)).handle(captor.capture());
    captured = captor.getAllValues().get(1);
    assertTrue(captured instanceof DAGAppMasterEventUserServiceFatalError);
    fatal = (DAGAppMasterEventUserServiceFatalError) captured;
    assertEquals(DAGAppMasterEventType.TASK_COMMUNICATOR_SERVICE_FATAL_ERROR, fatal.getType());
    assertTrue(
        fatal.getError().getMessage().contains("TestException_" + "registerRunningContainer"));
    assertTrue(fatal.getDiagnosticInfo().contains("registering running Container"));
    assertTrue(fatal.getDiagnosticInfo().contains(expectedId));
  } finally {
    manager.stop();
  }
}
/**
 * Default Mockito answer for a {@link TaskCommunicator} mock that simulates a misbehaving plugin.
 *
 * <p>Every method declared on {@link TaskCommunicator} throws a {@link RuntimeException} whose
 * message is {@code "TestException_" + methodName}, except {@code getContext},
 * {@code initialize}, {@code start} and {@code shutdown}. Those four, and any method declared on
 * another class, fall through to the real implementation.
 */
private static class ExceptionAnswer implements Answer<Object> {
  @Override
  public Object answer(InvocationOnMock invocation) throws Throwable {
    Method method = invocation.getMethod();
    String name = method.getName();

    // Methods that must keep working so the manager can be constructed, started and stopped.
    boolean exempt = name.equals("getContext") || name.equals("initialize")
        || name.equals("start") || name.equals("shutdown");

    if (method.getDeclaringClass().equals(TaskCommunicator.class) && !exempt) {
      throw new RuntimeException("TestException_" + name);
    }
    return invocation.callRealMethod();
  }
}
static class TaskCommManagerForMultipleCommTest extends TaskCommunicatorManager {
// All variables setup as static since methods being overridden are invoked by the ContainerLauncherRouter ctor,
// and regular variables will not be initialized at this point.
private static final AtomicInteger numTaskComms = new AtomicInteger(0);
private static final Set<Integer> taskCommIndices = new HashSet<>();
private static final TaskCommunicator yarnTaskComm = mock(TaskCommunicator.class);
private static final TaskCommunicator uberTaskComm = mock(TaskCommunicator.class);
private static final AtomicBoolean yarnTaskCommCreated = new AtomicBoolean(false);
private static final AtomicBoolean uberTaskCommCreated = new AtomicBoolean(false);
private static final List<TaskCommunicatorContext> taskCommContexts =
new LinkedList<>();
private static final List<String> taskCommNames = new LinkedList<>();
private static final List<TaskCommunicator> testTaskComms = new LinkedList<>();
public static void reset() {
numTaskComms.set(0);
taskCommIndices.clear();
yarnTaskCommCreated.set(false);
uberTaskCommCreated.set(false);
taskCommContexts.clear();
taskCommNames.clear();
testTaskComms.clear();
}
public TaskCommManagerForMultipleCommTest(AppContext context,
TaskHeartbeatHandler thh,
ContainerHeartbeatHandler chh,
List<NamedEntityDescriptor> taskCommunicatorDescriptors) throws TezException {
super(context, thh, chh, taskCommunicatorDescriptors);
}
@Override
TaskCommunicator createTaskCommunicator(NamedEntityDescriptor taskCommDescriptor,
int taskCommIndex) throws TezException {
numTaskComms.incrementAndGet();
boolean added = taskCommIndices.add(taskCommIndex);
assertTrue("Cannot add multiple taskComms with the same index", added);
taskCommNames.add(taskCommDescriptor.getEntityName());
return super.createTaskCommunicator(taskCommDescriptor, taskCommIndex);
}
@Override
TaskCommunicator createDefaultTaskCommunicator(
TaskCommunicatorContext taskCommunicatorContext) {
taskCommContexts.add(taskCommunicatorContext);
yarnTaskCommCreated.set(true);
testTaskComms.add(yarnTaskComm);
return yarnTaskComm;
}
@Override
TaskCommunicator createUberTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext) {
taskCommContexts.add(taskCommunicatorContext);
uberTaskCommCreated.set(true);
testTaskComms.add(uberTaskComm);
return uberTaskComm;
}
@Override
TaskCommunicator createCustomTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext,
NamedEntityDescriptor taskCommDescriptor) throws TezException {
taskCommContexts.add(taskCommunicatorContext);
TaskCommunicator spyComm =
spy(super.createCustomTaskCommunicator(taskCommunicatorContext, taskCommDescriptor));
testTaskComms.add(spyComm);
return spyComm;
}
public static int getNumTaskComms() {
return numTaskComms.get();
}
public static boolean getYarnTaskCommCreated() {
return yarnTaskCommCreated.get();
}
public static boolean getUberTaskCommCreated() {
return uberTaskCommCreated.get();
}
public static TaskCommunicatorContext getTaskCommContext(int taskCommIndex) {
return taskCommContexts.get(taskCommIndex);
}
public static String getTaskCommName(int taskCommIndex) {
return taskCommNames.get(taskCommIndex);
}
public static TaskCommunicator getTestTaskComm(int taskCommIndex) {
return testTaskComms.get(taskCommIndex);
}
}
/**
 * Minimal {@link TaskCommunicator} used as the "custom" plugin in these tests; the tests refer to
 * it by class name when building a {@code NamedEntityDescriptor}.
 * Every callback is a no-op and the two getters return {@code null}.
 */
public static class FakeTaskComm extends TaskCommunicator {
public FakeTaskComm(TaskCommunicatorContext taskCommunicatorContext) {
super(taskCommunicatorContext);
}
// No-op.
@Override
public void registerRunningContainer(ContainerId containerId, String hostname, int port) {
}
// No-op.
@Override
public void registerContainerEnd(ContainerId containerId, ContainerEndReason endReason, String diagnostics) {
}
// No-op.
@Override
public void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec,
Map<String, LocalResource> additionalResources,
Credentials credentials, boolean credentialsChanged,
int priority) {
}
// No-op.
@Override
public void unregisterRunningTaskAttempt(TezTaskAttemptID taskAttemptID,
TaskAttemptEndReason endReason, String diagnostics) {
}
// This fake exposes no address.
@Override
public InetSocketAddress getAddress() {
return null;
}
// No-op.
@Override
public void onVertexStateUpdated(VertexStateUpdate stateUpdate) {
}
// No-op.
@Override
public void dagComplete(int dagIdentifier) {
}
// This fake exposes no meta info.
@Override
public Object getMetaInfo() {
return null;
}
}
// Name returned by the mocked DAG in testReportFailureFromTaskCommunicator and passed to the
// DagInfoImplForTest that TaskCommForFailureTest reports errors against.
private static final String DAG_NAME = "dagName";
// Index used to build the mocked DAG's TezDAGID and the same DagInfoImplForTest.
private static final int DAG_INDEX = 1;
public static class TaskCommForFailureTest extends TaskCommunicator {
public TaskCommForFailureTest(
TaskCommunicatorContext taskCommunicatorContext) {
super(taskCommunicatorContext);
}
@Override
public void registerRunningContainer(ContainerId containerId, String hostname, int port) throws
ServicePluginException {
getContext()
.reportError(ServicePluginErrorDefaults.SERVICE_UNAVAILABLE, "ReportError", new DagInfoImplForTest(DAG_INDEX, DAG_NAME));
}
@Override
public void registerContainerEnd(ContainerId containerId, ContainerEndReason endReason,
@Nullable String diagnostics) throws ServicePluginException {
}
@Override
public void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec,
Map<String, LocalResource> additionalResources,
Credentials credentials, boolean credentialsChanged,
int priority) throws ServicePluginException {
}
@Override
public void unregisterRunningTaskAttempt(TezTaskAttemptID taskAttemptID,
TaskAttemptEndReason endReason,
@Nullable String diagnostics) throws
ServicePluginException {
}
@Override
public InetSocketAddress getAddress() throws ServicePluginException {
return null;
}
@Override
public void onVertexStateUpdated(VertexStateUpdate stateUpdate) throws ServicePluginException {
}
@Override
public void dagComplete(int dagIdentifier) throws ServicePluginException {
getContext().reportError(ServicePluginErrorDefaults.INCONSISTENT_STATE, "ReportedFatalError", null);
}
@Override
public Object getMetaInfo() throws ServicePluginException {
return null;
}
}
}