| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.tez.dag.app.web; |
| |
| import static org.mockito.Mockito.any; |
| import static org.mockito.Mockito.anyString; |
| import static org.mockito.Mockito.doNothing; |
| import static org.mockito.Mockito.doReturn; |
| import static org.mockito.Mockito.eq; |
| import static org.mockito.Mockito.mock; |
| import static org.mockito.Mockito.reset; |
| import static org.mockito.Mockito.spy; |
| import static org.mockito.Mockito.verify; |
| import static org.mockito.Mockito.when; |
| |
| import javax.servlet.http.HttpServletRequest; |
| import javax.servlet.http.HttpServletResponse; |
| |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.Collections; |
| import java.util.Comparator; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Set; |
| import java.util.TreeMap; |
| |
| import com.google.common.collect.ImmutableList; |
| import com.google.common.collect.ImmutableMap; |
| import com.google.common.collect.Maps; |
| import org.apache.hadoop.conf.Configuration; |
| import org.apache.hadoop.security.UserGroupInformation; |
| import org.apache.hadoop.yarn.webapp.Controller; |
| import org.apache.tez.common.counters.TezCounters; |
| import org.apache.tez.common.security.ACLManager; |
| import org.apache.tez.dag.api.TezConfiguration; |
| import org.apache.tez.dag.api.client.ProgressBuilder; |
| import org.apache.tez.dag.api.oldrecords.TaskAttemptState; |
| import org.apache.tez.dag.api.oldrecords.TaskState; |
| import org.apache.tez.dag.app.AppContext; |
| import org.apache.tez.dag.app.dag.DAG; |
| import org.apache.tez.dag.app.dag.DAGState; |
| import org.apache.tez.dag.app.dag.Task; |
| import org.apache.tez.dag.app.dag.TaskAttempt; |
| import org.apache.tez.dag.app.dag.Vertex; |
| import org.apache.tez.dag.app.dag.VertexState; |
| import org.apache.tez.dag.records.TezDAGID; |
| import org.apache.tez.dag.records.TezTaskAttemptID; |
| import org.apache.tez.dag.records.TezTaskID; |
| import org.apache.tez.dag.records.TezVertexID; |
| import org.junit.Assert; |
| import org.junit.Before; |
| import org.junit.Test; |
| import org.mockito.ArgumentCaptor; |
| import org.mockito.Captor; |
| import org.mockito.MockitoAnnotations; |
| |
| public class TestAMWebController { |
| AppContext mockAppContext; |
| Controller.RequestContext mockRequestContext; |
| HttpServletResponse mockResponse; |
| HttpServletRequest mockRequest; |
| String[] userGroups = {}; |
| |
| @Before |
| public void setup() { |
| MockitoAnnotations.initMocks(this); |
| |
| mockAppContext = mock(AppContext.class); |
| Configuration conf = new Configuration(false); |
| conf.set(TezConfiguration.TEZ_HISTORY_URL_BASE, "http://uihost:9001/foo"); |
| when(mockAppContext.getAMConf()).thenReturn(conf); |
| mockRequestContext = mock(Controller.RequestContext.class); |
| mockResponse = mock(HttpServletResponse.class); |
| mockRequest = mock(HttpServletRequest.class); |
| } |
| |
| @Test |
| public void testEncodeHeaders() { |
| String validOrigin = "http://localhost:12345"; |
| String encodedValidOrigin = AMWebController.encodeHeader(validOrigin); |
| Assert.assertEquals("Valid origin encoding should match exactly", |
| validOrigin, encodedValidOrigin); |
| |
| String httpResponseSplitOrigin = validOrigin + " \nSecondHeader: value"; |
| String encodedResponseSplitOrigin = |
| AMWebController.encodeHeader(httpResponseSplitOrigin); |
| Assert.assertEquals("Http response split origin should be protected against", |
| validOrigin, encodedResponseSplitOrigin); |
| |
| // Test Origin List |
| String validOriginList = "http://foo.example.com:12345 http://bar.example.com:12345"; |
| String encodedValidOriginList = AMWebController |
| .encodeHeader(validOriginList); |
| Assert.assertEquals("Valid origin list encoding should match exactly", |
| validOriginList, encodedValidOriginList); |
| } |
| |
| @Test(timeout = 5000) |
| public void testCorsHeadersWithOrigin() { |
| AMWebController amWebController = new AMWebController(mockRequestContext, mockAppContext, |
| "TEST_HISTORY_URL"); |
| AMWebController spy = spy(amWebController); |
| String originURL = "http://origin.com:8080"; |
| |
| doReturn(mockResponse).when(spy).response(); |
| |
| doReturn(mockRequest).when(spy).request(); |
| doReturn(originURL).when(mockRequest).getHeader(AMWebController.ORIGIN); |
| |
| spy.setCorsHeaders(); |
| verify(mockResponse).setHeader("Access-Control-Allow-Origin", originURL); |
| } |
| |
| @Test(timeout = 5000) |
| public void testCorsHeadersAreSet() { |
| AMWebController amWebController = new AMWebController(mockRequestContext, mockAppContext, |
| "TEST_HISTORY_URL"); |
| AMWebController spy = spy(amWebController); |
| doReturn(mockRequest).when(spy).request(); |
| doReturn(mockResponse).when(spy).response(); |
| spy.setCorsHeaders(); |
| |
| verify(mockResponse).setHeader("Access-Control-Allow-Origin", "http://uihost:9001"); |
| verify(mockResponse).setHeader("Access-Control-Allow-Credentials", "true"); |
| verify(mockResponse).setHeader("Access-Control-Allow-Methods", "GET, HEAD"); |
| verify(mockResponse).setHeader("Access-Control-Allow-Headers", |
| "X-Requested-With,Content-Type,Accept,Origin"); |
| } |
| |
| @Test (timeout = 5000) |
| public void sendErrorResponseIfNoAccess() throws Exception { |
| AMWebController amWebController = new AMWebController(mockRequestContext, mockAppContext, |
| "TEST_HISTORY_URL"); |
| AMWebController spy = spy(amWebController); |
| |
| doReturn(false).when(spy).hasAccess(); |
| doNothing().when(spy).setCorsHeaders(); |
| doReturn(mockResponse).when(spy).response(); |
| doReturn(mockRequest).when(spy).request(); |
| doReturn("dummyuser").when(mockRequest).getRemoteUser(); |
| |
| spy.getDagProgress(); |
| verify(mockResponse).sendError(eq(HttpServletResponse.SC_UNAUTHORIZED), anyString()); |
| reset(mockResponse); |
| |
| spy.getVertexProgress(); |
| verify(mockResponse).sendError(eq(HttpServletResponse.SC_UNAUTHORIZED), anyString()); |
| reset(mockResponse); |
| |
| spy.getVertexProgresses(); |
| verify(mockResponse).sendError(eq(HttpServletResponse.SC_UNAUTHORIZED), anyString()); |
| } |
| |
| @Captor |
| ArgumentCaptor<Map<String, AMWebController.ProgressInfo>> singleResultCaptor; |
| |
| @Test (timeout = 5000) |
| public void testDagProgressResponse() { |
| AMWebController amWebController = new AMWebController(mockRequestContext, mockAppContext, |
| "TEST_HISTORY_URL"); |
| AMWebController spy = spy(amWebController); |
| DAG mockDAG = mock(DAG.class); |
| |
| doReturn(true).when(spy).hasAccess(); |
| doNothing().when(spy).setCorsHeaders(); |
| doReturn("42").when(spy).$(WebUIService.DAG_ID); |
| doReturn(mockResponse).when(spy).response(); |
| doReturn(TezDAGID.fromString("dag_1422960590892_0007_42")).when(mockDAG).getID(); |
| doReturn(66.0f).when(mockDAG).getCompletedTaskProgress(); |
| doReturn(mockDAG).when(mockAppContext).getCurrentDAG(); |
| doNothing().when(spy).renderJSON(any()); |
| spy.getDagProgress(); |
| verify(spy).renderJSON(singleResultCaptor.capture()); |
| |
| final Map<String, AMWebController.ProgressInfo> result = singleResultCaptor.getValue(); |
| Assert.assertEquals(1, result.size()); |
| Assert.assertTrue(result.containsKey("dagProgress")); |
| AMWebController.ProgressInfo progressInfo = result.get("dagProgress"); |
| Assert.assertTrue("dag_1422960590892_0007_42".equals(progressInfo.getId())); |
| Assert.assertEquals(66.0, progressInfo.getProgress(), 0.1); |
| } |
| |
| @Test (timeout = 5000) |
| public void testVertexProgressResponse() { |
| AMWebController amWebController = new AMWebController(mockRequestContext, mockAppContext, |
| "TEST_HISTORY_URL"); |
| AMWebController spy = spy(amWebController); |
| DAG mockDAG = mock(DAG.class); |
| Vertex mockVertex = mock(Vertex.class); |
| |
| doReturn(true).when(spy).hasAccess(); |
| doReturn("42").when(spy).$(WebUIService.DAG_ID); |
| doReturn("43").when(spy).$(WebUIService.VERTEX_ID); |
| doReturn(mockResponse).when(spy).response(); |
| |
| doReturn(TezDAGID.fromString("dag_1422960590892_0007_42")).when(mockDAG).getID(); |
| doReturn(mockDAG).when(mockAppContext).getCurrentDAG(); |
| doReturn(mockVertex).when(mockDAG).getVertex(any(TezVertexID.class)); |
| doReturn(66.0f).when(mockVertex).getCompletedTaskProgress(); |
| doNothing().when(spy).renderJSON(any()); |
| doNothing().when(spy).setCorsHeaders(); |
| |
| spy.getVertexProgress(); |
| verify(spy).renderJSON(singleResultCaptor.capture()); |
| |
| final Map<String, AMWebController.ProgressInfo> result = singleResultCaptor.getValue(); |
| Assert.assertEquals(1, result.size()); |
| Assert.assertTrue(result.containsKey("vertexProgress")); |
| AMWebController.ProgressInfo progressInfo = result.get("vertexProgress"); |
| Assert.assertTrue("vertex_1422960590892_0007_42_43".equals(progressInfo.getId())); |
| Assert.assertEquals(66.0f, progressInfo.getProgress(), 0.1); |
| } |
| |
| @Test (timeout = 5000) |
| public void testHasAccessWithAclsDisabled() { |
| Configuration conf = new Configuration(false); |
| conf.setBoolean(TezConfiguration.TEZ_AM_ACLS_ENABLED, false); |
| ACLManager aclManager = new ACLManager("amUser", conf); |
| |
| when(mockAppContext.getAMACLManager()).thenReturn(aclManager); |
| |
| Assert.assertEquals(true, AMWebController._hasAccess(null, mockAppContext)); |
| |
| UserGroupInformation mockUser = UserGroupInformation.createUserForTesting( |
| "mockUser", userGroups); |
| Assert.assertEquals(true, AMWebController._hasAccess(mockUser, mockAppContext)); |
| } |
| |
| @Test (timeout = 5000) |
| public void testHasAccess() { |
| Configuration conf = new Configuration(false); |
| conf.setBoolean(TezConfiguration.TEZ_AM_ACLS_ENABLED, true); |
| ACLManager aclManager = new ACLManager("amUser", conf); |
| |
| when(mockAppContext.getAMACLManager()).thenReturn(aclManager); |
| |
| Assert.assertEquals(false, AMWebController._hasAccess(null, mockAppContext)); |
| |
| UserGroupInformation mockUser = UserGroupInformation.createUserForTesting( |
| "mockUser", userGroups); |
| Assert.assertEquals(false, AMWebController._hasAccess(mockUser, mockAppContext)); |
| |
| UserGroupInformation testUser = UserGroupInformation.createUserForTesting( |
| "amUser", userGroups); |
| Assert.assertEquals(true, AMWebController._hasAccess(testUser, mockAppContext)); |
| } |
| |
| |
| // AM Webservice Version 2 |
| //ArgumentCaptor<Map<String, Object>> returnResultCaptor; |
| @Captor |
| ArgumentCaptor<Map<String,Object>> returnResultCaptor; |
| |
| @SuppressWarnings("unchecked") |
| @Test(timeout = 5000) |
| public void testGetDagInfo() { |
| AMWebController amWebController = new AMWebController(mockRequestContext, mockAppContext, |
| "TEST_HISTORY_URL"); |
| AMWebController spy = spy(amWebController); |
| DAG mockDAG = mock(DAG.class); |
| |
| |
| doReturn(TezDAGID.fromString("dag_1422960590892_0007_42")).when(mockDAG).getID(); |
| doReturn(66.0f).when(mockDAG).getCompletedTaskProgress(); |
| doReturn(DAGState.RUNNING).when(mockDAG).getState(); |
| TezCounters counters = new TezCounters(); |
| counters.addGroup("g1", "g1"); |
| counters.addGroup("g2", "g2"); |
| counters.addGroup("g3", "g3"); |
| counters.addGroup("g4", "g4"); |
| counters.findCounter("g1", "g1_c1").setValue(100); |
| counters.findCounter("g1", "g1_c2").setValue(100); |
| counters.findCounter("g2", "g2_c3").setValue(100); |
| counters.findCounter("g2", "g2_c4").setValue(100); |
| counters.findCounter("g3", "g3_c5").setValue(100); |
| counters.findCounter("g3", "g3_c6").setValue(100); |
| |
| doReturn(counters).when(mockDAG).getAllCounters(); |
| doReturn(counters).when(mockDAG).getCachedCounters(); |
| |
| doReturn(true).when(spy).setupResponse(); |
| doReturn(mockDAG).when(spy).checkAndGetDAGFromRequest(); |
| doNothing().when(spy).renderJSON(any()); |
| |
| Map<String, Set<String>> counterNames = new HashMap<String, Set<String>>(); |
| counterNames.put("*", null); |
| doReturn(counterNames).when(spy).getCounterListFromRequest(); |
| |
| spy.getDagInfo(); |
| verify(spy).renderJSON(returnResultCaptor.capture()); |
| |
| final Map<String, Object> result = returnResultCaptor.getValue(); |
| Assert.assertEquals(1, result.size()); |
| Assert.assertTrue(result.containsKey("dag")); |
| Map<String, String> dagInfo = (Map<String, String>) result.get("dag"); |
| |
| Assert.assertEquals(4, dagInfo.size()); |
| Assert.assertTrue("dag_1422960590892_0007_42".equals(dagInfo.get("id"))); |
| Assert.assertEquals("66.0", dagInfo.get("progress")); |
| Assert.assertEquals("RUNNING", dagInfo.get("status")); |
| Assert.assertNotNull(dagInfo.get("counters")); |
| |
| } |
| |
| @SuppressWarnings("unchecked") |
| @Test(timeout = 5000) |
| public void testGetVerticesInfoGetAll() { |
| Vertex mockVertex1 = createMockVertex("vertex_1422960590892_0007_42_00", VertexState.RUNNING, |
| 0.33f, 3); |
| Vertex mockVertex2 = createMockVertex("vertex_1422960590892_0007_42_01", VertexState.SUCCEEDED, |
| 1.0f, 5); |
| |
| final Map<String, Object> result = getVerticesTestHelper(0, mockVertex1, mockVertex2); |
| |
| Assert.assertEquals(1, result.size()); |
| |
| Assert.assertTrue(result.containsKey("vertices")); |
| ArrayList<Map<String, String>> verticesInfo = (ArrayList<Map<String, String>>) result.get("vertices"); |
| Assert.assertEquals(2, verticesInfo.size()); |
| |
| Map<String, String> vertex1Result = verticesInfo.get(0); |
| Map<String, String> vertex2Result = verticesInfo.get(1); |
| |
| verifySingleVertexResult(mockVertex1, vertex1Result); |
| verifySingleVertexResult(mockVertex2, vertex2Result); |
| } |
| |
| @SuppressWarnings("unchecked") |
| @Test(timeout = 5000) |
| public void testGetVerticesInfoGetPartial() { |
| Vertex mockVertex1 = createMockVertex("vertex_1422960590892_0007_42_00", VertexState.RUNNING, |
| 0.33f, 3); |
| Vertex mockVertex2 = createMockVertex("vertex_1422960590892_0007_42_01", VertexState.SUCCEEDED, |
| 1.0f, 5); |
| |
| final Map<String, Object> result = getVerticesTestHelper(1, mockVertex1, mockVertex2); |
| |
| Assert.assertEquals(1, result.size()); |
| |
| Assert.assertTrue(result.containsKey("vertices")); |
| List<Map<String, String>> verticesInfo = (List<Map<String, String>>) result.get("vertices"); |
| Assert.assertEquals(1, verticesInfo.size()); |
| |
| Map<String, String> vertex1Result = verticesInfo.get(0); |
| |
| verifySingleVertexResult(mockVertex1, vertex1Result); |
| } |
| |
| Map<String, Object> getVerticesTestHelper(int numVerticesRequested, Vertex mockVertex1, |
| Vertex mockVertex2) { |
| DAG mockDAG = mock(DAG.class); |
| doReturn(TezDAGID.fromString("dag_1422960590892_0007_42")).when(mockDAG).getID(); |
| |
| TezVertexID vertexId1 = mockVertex1.getVertexId(); |
| doReturn(mockVertex1).when(mockDAG).getVertex(vertexId1); |
| TezVertexID vertexId2 = mockVertex2.getVertexId(); |
| doReturn(mockVertex2).when(mockDAG).getVertex(vertexId2); |
| |
| AMWebController amWebController = new AMWebController(mockRequestContext, mockAppContext, |
| "TEST_HISTORY_URL"); |
| AMWebController spy = spy(amWebController); |
| |
| doReturn(ImmutableMap.of( |
| mockVertex1.getVertexId(), mockVertex1, |
| mockVertex2.getVertexId(), mockVertex2 |
| )).when(mockDAG).getVertices(); |
| |
| doReturn(true).when(spy).setupResponse(); |
| doReturn(mockDAG).when(spy).checkAndGetDAGFromRequest(); |
| |
| Map<String, Set<String>> counterList = new TreeMap<String, Set<String>>(); |
| doReturn(counterList).when(spy).getCounterListFromRequest(); |
| |
| List<Integer> requested; |
| if (numVerticesRequested == 0) { |
| requested = ImmutableList.of(); |
| } else { |
| requested = ImmutableList.of(mockVertex1.getVertexId().getId()); |
| } |
| |
| doReturn(requested).when(spy).getVertexIDsFromRequest(); |
| doNothing().when(spy).renderJSON(any()); |
| |
| spy.getVerticesInfo(); |
| verify(spy).renderJSON(returnResultCaptor.capture()); |
| |
| return returnResultCaptor.getValue(); |
| } |
| |
| private Vertex createMockVertex(String vertexIDStr, VertexState status, float progress, |
| int taskCounts) { |
| ProgressBuilder pb = new ProgressBuilder(); |
| pb.setTotalTaskCount(taskCounts); |
| pb.setSucceededTaskCount(taskCounts * 2); |
| pb.setFailedTaskAttemptCount(taskCounts * 3); |
| pb.setKilledTaskAttemptCount(taskCounts * 4); |
| pb.setRunningTaskCount(taskCounts * 5); |
| |
| Vertex mockVertex = mock(Vertex.class); |
| doReturn(TezVertexID.fromString(vertexIDStr)).when(mockVertex).getVertexId(); |
| doReturn(status).when(mockVertex).getState(); |
| doReturn(progress).when(mockVertex).getProgress(); |
| doReturn(pb).when(mockVertex).getVertexProgress(); |
| doReturn(1L).when(mockVertex).getInitTime(); |
| doReturn(1L).when(mockVertex).getStartTime(); |
| doReturn(2L).when(mockVertex).getFinishTime(); |
| doReturn(1L).when(mockVertex).getFirstTaskStartTime(); |
| doReturn(2L).when(mockVertex).getLastTaskFinishTime(); |
| |
| TezCounters counters = new TezCounters(); |
| counters.addGroup("g1", "g1"); |
| counters.addGroup("g2", "g2"); |
| counters.addGroup("g3", "g3"); |
| counters.addGroup("g4", "g4"); |
| counters.findCounter("g1", "g1_c1").setValue(100); |
| counters.findCounter("g1", "g1_c2").setValue(100); |
| counters.findCounter("g2", "g2_c3").setValue(100); |
| counters.findCounter("g2", "g2_c4").setValue(100); |
| counters.findCounter("g3", "g3_c5").setValue(100); |
| counters.findCounter("g3", "g3_c6").setValue(100); |
| |
| doReturn(counters).when(mockVertex).getAllCounters(); |
| doReturn(counters).when(mockVertex).getCachedCounters(); |
| |
| return mockVertex; |
| } |
| |
| |
| private void verifySingleVertexResult(Vertex mockVertex2, Map<String, String> vertex2Result) { |
| ProgressBuilder progress; |
| Assert.assertEquals(mockVertex2.getVertexId().toString(), vertex2Result.get("id")); |
| Assert.assertEquals(mockVertex2.getState().toString(), vertex2Result.get("status")); |
| Assert.assertEquals(Float.toString(mockVertex2.getCompletedTaskProgress()), vertex2Result.get("progress")); |
| progress = mockVertex2.getVertexProgress(); |
| Assert.assertEquals(Integer.toString(progress.getTotalTaskCount()), |
| vertex2Result.get("totalTasks")); |
| Assert.assertEquals(Integer.toString(progress.getRunningTaskCount()), |
| vertex2Result.get("runningTasks")); |
| Assert.assertEquals(Integer.toString(progress.getSucceededTaskCount()), |
| vertex2Result.get("succeededTasks")); |
| Assert.assertEquals(Integer.toString(progress.getKilledTaskAttemptCount()), |
| vertex2Result.get("killedTaskAttempts")); |
| Assert.assertEquals(Integer.toString(progress.getFailedTaskAttemptCount()), |
| vertex2Result.get("failedTaskAttempts")); |
| String str0 = Long.toString(mockVertex2.getInitTime()); |
| String str1 = vertex2Result.get("initTime"); |
| Assert.assertEquals(Long.toString(mockVertex2.getInitTime()), |
| vertex2Result.get("initTime")); |
| Assert.assertEquals(Long.toString(mockVertex2.getStartTime()), |
| vertex2Result.get("startTime")); |
| Assert.assertEquals(Long.toString(mockVertex2.getFinishTime()), |
| vertex2Result.get("finishTime")); |
| Assert.assertEquals(Long.toString(mockVertex2.getFirstTaskStartTime()), |
| vertex2Result.get("firstTaskStartTime")); |
| Assert.assertEquals(Long.toString(mockVertex2.getLastTaskFinishTime()), |
| vertex2Result.get("lastTaskFinishTime")); |
| } |
| |
| //-- Get Tasks Info Tests ----------------------------------------------------------------------- |
| |
| @SuppressWarnings("unchecked") |
| @Test(timeout = 5000) |
| public void testGetTasksInfoWithTaskIds() { |
| List <Task> tasks = createMockTasks(); |
| List <Integer> vertexMinIds = Arrays.asList(); |
| List <List <Integer>> taskMinIds = Arrays.asList(Arrays.asList(0, 0), |
| Arrays.asList(0, 3), |
| Arrays.asList(0, 1)); |
| |
| // Fetch All |
| Map<String, Object> result = getTasksTestHelper(tasks, taskMinIds, vertexMinIds, |
| AMWebController.MAX_QUERIED); |
| |
| Assert.assertEquals(1, result.size()); |
| Assert.assertTrue(result.containsKey("tasks")); |
| |
| ArrayList<Map<String, String>> tasksInfo = (ArrayList<Map<String, String>>) result. |
| get("tasks"); |
| Assert.assertEquals(3, tasksInfo.size()); |
| |
| verifySingleTaskResult(tasks.get(0), tasksInfo.get(0)); |
| verifySingleTaskResult(tasks.get(3), tasksInfo.get(1)); |
| verifySingleTaskResult(tasks.get(1), tasksInfo.get(2)); |
| |
| // With limit |
| result = getTasksTestHelper(tasks, taskMinIds, vertexMinIds, 2); |
| |
| Assert.assertEquals(1, result.size()); |
| Assert.assertTrue(result.containsKey("tasks")); |
| |
| tasksInfo = (ArrayList<Map<String, String>>) result.get("tasks"); |
| Assert.assertEquals(2, tasksInfo.size()); |
| |
| verifySingleTaskResult(tasks.get(0), tasksInfo.get(0)); |
| verifySingleTaskResult(tasks.get(3), tasksInfo.get(1)); |
| } |
| |
| @SuppressWarnings("unchecked") |
| @Test(timeout = 5000) |
| public void testGetTasksInfoGracefulTaskFetch() { |
| List <Task> tasks = createMockTasks(); |
| List <Integer> vertexMinIds = Arrays.asList(); |
| List <List <Integer>> taskMinIds = Arrays.asList(Arrays.asList(0, 0), |
| Arrays.asList(0, 6), |
| Arrays.asList(0, 1)); |
| |
| // Fetch All |
| Map<String, Object> result = getTasksTestHelper(tasks, taskMinIds, vertexMinIds, |
| AMWebController.MAX_QUERIED); |
| |
| Assert.assertEquals(1, result.size()); |
| Assert.assertTrue(result.containsKey("tasks")); |
| |
| ArrayList<Map<String, String>> tasksInfo = (ArrayList<Map<String, String>>) result. |
| get("tasks"); |
| Assert.assertEquals(2, tasksInfo.size()); |
| |
| verifySingleTaskResult(tasks.get(0), tasksInfo.get(0)); |
| verifySingleTaskResult(tasks.get(1), tasksInfo.get(1)); |
| } |
| |
| @SuppressWarnings("unchecked") |
| @Test(timeout = 5000) |
| public void testGetTasksInfoWithVertexId() { |
| List <Task> tasks = createMockTasks(); |
| List <Integer> vertexMinIds = Arrays.asList(0); |
| List <List <Integer>> taskMinIds = Arrays.asList(); |
| |
| Map<String, Object> result = getTasksTestHelper(tasks, taskMinIds, vertexMinIds, |
| AMWebController.MAX_QUERIED); |
| |
| Assert.assertEquals(1, result.size()); |
| Assert.assertTrue(result.containsKey("tasks")); |
| |
| ArrayList<Map<String, String>> tasksInfo = (ArrayList<Map<String, String>>) result. |
| get("tasks"); |
| Assert.assertEquals(4, tasksInfo.size()); |
| |
| sortMapList(tasksInfo, "id"); |
| verifySingleTaskResult(tasks.get(0), tasksInfo.get(0)); |
| verifySingleTaskResult(tasks.get(1), tasksInfo.get(1)); |
| verifySingleTaskResult(tasks.get(2), tasksInfo.get(2)); |
| verifySingleTaskResult(tasks.get(3), tasksInfo.get(3)); |
| |
| // With limit |
| result = getTasksTestHelper(tasks, taskMinIds, vertexMinIds, 2); |
| |
| Assert.assertEquals(1, result.size()); |
| Assert.assertTrue(result.containsKey("tasks")); |
| |
| tasksInfo = (ArrayList<Map<String, String>>) result.get("tasks"); |
| Assert.assertEquals(2, tasksInfo.size()); |
| } |
| |
| @SuppressWarnings("unchecked") |
| @Test(timeout = 5000) |
| public void testGetTasksInfoWithJustDAGId() { |
| List <Task> tasks = createMockTasks(); |
| List <Integer> vertexMinIds = Arrays.asList(); |
| List <List <Integer>> taskMinIds = Arrays.asList(); |
| |
| Map<String, Object> result = getTasksTestHelper(tasks, taskMinIds, vertexMinIds, |
| AMWebController.MAX_QUERIED); |
| |
| Assert.assertEquals(1, result.size()); |
| Assert.assertTrue(result.containsKey("tasks")); |
| |
| ArrayList<Map<String, String>> tasksInfo = (ArrayList<Map<String, String>>) result. |
| get("tasks"); |
| Assert.assertEquals(4, tasksInfo.size()); |
| |
| sortMapList(tasksInfo, "id"); |
| verifySingleTaskResult(tasks.get(0), tasksInfo.get(0)); |
| verifySingleTaskResult(tasks.get(1), tasksInfo.get(1)); |
| verifySingleTaskResult(tasks.get(2), tasksInfo.get(2)); |
| verifySingleTaskResult(tasks.get(3), tasksInfo.get(3)); |
| |
| // With limit |
| result = getTasksTestHelper(tasks, taskMinIds, vertexMinIds, 2); |
| |
| Assert.assertEquals(1, result.size()); |
| Assert.assertTrue(result.containsKey("tasks")); |
| |
| tasksInfo = (ArrayList<Map<String, String>>) result.get("tasks"); |
| Assert.assertEquals(2, tasksInfo.size()); |
| } |
| |
| private void sortMapList(ArrayList<Map<String, String>> list, String propertyName) { |
| class MapComparator implements Comparator<Map<String, String>> { |
| private final String key; |
| |
| public MapComparator(String key) { |
| this.key = key; |
| } |
| |
| public int compare(Map<String, String> first, Map<String, String> second) { |
| String firstValue = first.get(key); |
| String secondValue = second.get(key); |
| return firstValue.compareTo(secondValue); |
| } |
| } |
| |
| Collections.sort(list, new MapComparator(propertyName)); |
| } |
| |
| Map<String, Object> getTasksTestHelper(List<Task> tasks, List <List <Integer>> taskMinIds, |
| List<Integer> vertexMinIds, Integer limit) { |
| //Creating mock DAG |
| DAG mockDAG = mock(DAG.class); |
| doReturn(TezDAGID.fromString("dag_1441301219877_0109_1")).when(mockDAG).getID(); |
| |
| //Creating mock vertex and attaching to mock DAG |
| TezVertexID vertexID = TezVertexID.fromString("vertex_1441301219877_0109_1_00"); |
| Vertex mockVertex = mock(Vertex.class); |
| doReturn(vertexID).when(mockVertex).getVertexId(); |
| |
| doReturn(mockVertex).when(mockDAG).getVertex(vertexID); |
| doReturn(ImmutableMap.of( |
| vertexID, mockVertex |
| )).when(mockDAG).getVertices(); |
| |
| //Creating mock tasks and attaching to mock vertex |
| Map<TezTaskID, Task> taskMap = Maps.newHashMap(); |
| for(Task task : tasks) { |
| TezTaskID taskId = task.getTaskID(); |
| int taskIndex = taskId.getId(); |
| doReturn(task).when(mockVertex).getTask(taskIndex); |
| taskMap.put(taskId, task); |
| } |
| doReturn(taskMap).when(mockVertex).getTasks(); |
| |
| //Creates & setup controller spy |
| AMWebController amWebController = new AMWebController(mockRequestContext, mockAppContext, |
| "TEST_HISTORY_URL"); |
| AMWebController spy = spy(amWebController); |
| doReturn(true).when(spy).setupResponse(); |
| doNothing().when(spy).renderJSON(any()); |
| |
| // Set mock query params |
| doReturn(limit).when(spy).getQueryParamInt(WebUIService.LIMIT); |
| doReturn(vertexMinIds).when(spy).getIntegersFromRequest(WebUIService.VERTEX_ID, limit); |
| doReturn(taskMinIds).when(spy).getIDsFromRequest(WebUIService.TASK_ID, limit, 2); |
| |
| // Set function mocks |
| doReturn(mockDAG).when(spy).checkAndGetDAGFromRequest(); |
| |
| Map<String, Set<String>> counterList = new TreeMap<String, Set<String>>(); |
| doReturn(counterList).when(spy).getCounterListFromRequest(); |
| |
| spy.getTasksInfo(); |
| verify(spy).renderJSON(returnResultCaptor.capture()); |
| |
| return returnResultCaptor.getValue(); |
| } |
| |
| private List<Task> createMockTasks() { |
| Task mockTask1 = createMockTask("task_1441301219877_0109_1_00_000000", TaskState.RUNNING, |
| 0.33f); |
| Task mockTask2 = createMockTask("task_1441301219877_0109_1_00_000001", TaskState.SUCCEEDED, |
| 1.0f); |
| Task mockTask3 = createMockTask("task_1441301219877_0109_1_00_000002", TaskState.SUCCEEDED, |
| .8f); |
| Task mockTask4 = createMockTask("task_1441301219877_0109_1_00_000003", TaskState.SUCCEEDED, |
| .8f); |
| |
| List <Task> tasks = Arrays.asList(mockTask1, mockTask2, mockTask3, mockTask4); |
| return tasks; |
| } |
| |
| private Task createMockTask(String taskIDStr, TaskState status, float progress) { |
| Task mockTask = mock(Task.class); |
| |
| doReturn(TezTaskID.fromString(taskIDStr)).when(mockTask).getTaskID(); |
| doReturn(status).when(mockTask).getState(); |
| doReturn(progress).when(mockTask).getProgress(); |
| |
| TezCounters counters = new TezCounters(); |
| counters.addGroup("g1", "g1"); |
| counters.addGroup("g2", "g2"); |
| counters.addGroup("g3", "g3"); |
| counters.addGroup("g4", "g4"); |
| counters.findCounter("g1", "g1_c1").setValue(101); |
| counters.findCounter("g1", "g1_c2").setValue(102); |
| counters.findCounter("g2", "g2_c3").setValue(103); |
| counters.findCounter("g2", "g2_c4").setValue(104); |
| counters.findCounter("g3", "g3_c5").setValue(105); |
| counters.findCounter("g3", "g3_c6").setValue(106); |
| |
| doReturn(counters).when(mockTask).getCounters(); |
| |
| return mockTask; |
| } |
| |
| private void verifySingleTaskResult(Task mockTask, Map<String, String> taskResult) { |
| Assert.assertEquals(3, taskResult.size()); |
| Assert.assertEquals(mockTask.getTaskID().toString(), taskResult.get("id")); |
| Assert.assertEquals(mockTask.getState().toString(), taskResult.get("status")); |
| Assert.assertEquals(Float.toString(mockTask.getProgress()), taskResult.get("progress")); |
| } |
| |
| //-- Get Attempts Info Tests ----------------------------------------------------------------------- |
| |
| @SuppressWarnings("unchecked") |
| @Test(timeout = 5000) |
| public void testGetAttemptsInfoWithIds() { |
| List <TaskAttempt> attempts = createMockAttempts(); |
| List <Integer> vertexMinIds = Arrays.asList(); |
| List <Integer> taskMinIds = Arrays.asList(); |
| List <List <Integer>> attemptMinIds = Arrays.asList(Arrays.asList(0, 0, 0), |
| Arrays.asList(0, 0, 1), |
| Arrays.asList(0, 0, 2), |
| Arrays.asList(0, 0, 3)); |
| |
| // Fetch All |
| Map<String, Object> result = getAttemptsTestHelper(attempts, attemptMinIds, vertexMinIds, |
| taskMinIds, AMWebController.MAX_QUERIED); |
| |
| Assert.assertEquals(1, result.size()); |
| Assert.assertTrue(result.containsKey("attempts")); |
| |
| ArrayList<Map<String, String>> attemptsInfo = (ArrayList<Map<String, String>>) result. |
| get("attempts"); |
| Assert.assertEquals(4, attemptsInfo.size()); |
| |
| verifySingleAttemptResult(attempts.get(0), attemptsInfo.get(0)); |
| verifySingleAttemptResult(attempts.get(1), attemptsInfo.get(1)); |
| verifySingleAttemptResult(attempts.get(2), attemptsInfo.get(2)); |
| verifySingleAttemptResult(attempts.get(3), attemptsInfo.get(3)); |
| |
| // With limit |
| result = getAttemptsTestHelper(attempts, attemptMinIds, vertexMinIds, taskMinIds, 2); |
| |
| Assert.assertEquals(1, result.size()); |
| Assert.assertTrue(result.containsKey("attempts")); |
| |
| attemptsInfo = (ArrayList<Map<String, String>>) result.get("attempts"); |
| Assert.assertEquals(2, attemptsInfo.size()); |
| |
| verifySingleAttemptResult(attempts.get(0), attemptsInfo.get(0)); |
| verifySingleAttemptResult(attempts.get(1), attemptsInfo.get(1)); |
| } |
| |
| Map<String, Object> getAttemptsTestHelper(List<TaskAttempt> attempts, List <List <Integer>> attemptMinIds, |
| List<Integer> vertexMinIds, List<Integer> taskMinIds, Integer limit) { |
| //Creating mock DAG |
| DAG mockDAG = mock(DAG.class); |
| doReturn(TezDAGID.fromString("dag_1441301219877_0109_1")).when(mockDAG).getID(); |
| |
| //Creating mock vertex and attaching to mock DAG |
| TezVertexID vertexID = TezVertexID.fromString("vertex_1441301219877_0109_1_00"); |
| Vertex mockVertex = mock(Vertex.class); |
| doReturn(vertexID).when(mockVertex).getVertexId(); |
| |
| doReturn(mockVertex).when(mockDAG).getVertex(vertexID); |
| doReturn(ImmutableMap.of( |
| vertexID, mockVertex |
| )).when(mockDAG).getVertices(); |
| |
| //Creating mock task and attaching to mock Vertex |
| TezTaskID taskID = TezTaskID.fromString("task_1441301219877_0109_1_00_000000"); |
| Task mockTask = mock(Task.class); |
| doReturn(taskID).when(mockTask).getTaskID(); |
| int taskIndex = taskID.getId(); |
| doReturn(mockTask).when(mockVertex).getTask(taskIndex); |
| doReturn(ImmutableMap.of( |
| taskID, mockTask |
| )).when(mockVertex).getTasks(); |
| |
| //Creating mock tasks and attaching to mock vertex |
| Map<TezTaskAttemptID, TaskAttempt> attemptsMap = Maps.newHashMap(); |
| for(TaskAttempt attempt : attempts) { |
| TezTaskAttemptID attemptId = attempt.getTaskAttemptID(); |
| doReturn(attempt).when(mockTask).getAttempt(attemptId); |
| attemptsMap.put(attemptId, attempt); |
| } |
| doReturn(attemptsMap).when(mockTask).getAttempts(); |
| |
| //Creates & setup controller spy |
| AMWebController amWebController = new AMWebController(mockRequestContext, mockAppContext, |
| "TEST_HISTORY_URL"); |
| AMWebController spy = spy(amWebController); |
| doReturn(true).when(spy).setupResponse(); |
| doNothing().when(spy).renderJSON(any()); |
| |
| // Set mock query params |
| doReturn(limit).when(spy).getQueryParamInt(WebUIService.LIMIT); |
| doReturn(vertexMinIds).when(spy).getIntegersFromRequest(WebUIService.VERTEX_ID, limit); |
| doReturn(taskMinIds).when(spy).getIDsFromRequest(WebUIService.TASK_ID, limit, 2); |
| doReturn(attemptMinIds).when(spy).getIDsFromRequest(WebUIService.ATTEMPT_ID, limit, 3); |
| |
| // Set function mocks |
| doReturn(mockDAG).when(spy).checkAndGetDAGFromRequest(); |
| |
| Map<String, Set<String>> counterList = new TreeMap<String, Set<String>>(); |
| doReturn(counterList).when(spy).getCounterListFromRequest(); |
| |
| spy.getAttemptsInfo(); |
| verify(spy).renderJSON(returnResultCaptor.capture()); |
| |
| return returnResultCaptor.getValue(); |
| } |
| |
| private List<TaskAttempt> createMockAttempts() { |
| TaskAttempt mockAttempt1 = createMockAttempt("attempt_1441301219877_0109_1_00_000000_0", TaskAttemptState.RUNNING, |
| 0.33f); |
| TaskAttempt mockAttempt2 = createMockAttempt("attempt_1441301219877_0109_1_00_000000_1", TaskAttemptState.SUCCEEDED, |
| 1.0f); |
| TaskAttempt mockAttempt3 = createMockAttempt("attempt_1441301219877_0109_1_00_000000_2", TaskAttemptState.FAILED, |
| .8f); |
| TaskAttempt mockAttempt4 = createMockAttempt("attempt_1441301219877_0109_1_00_000000_3", TaskAttemptState.SUCCEEDED, |
| .8f); |
| |
| List <TaskAttempt> attempts = Arrays.asList(mockAttempt1, mockAttempt2, mockAttempt3, mockAttempt4); |
| return attempts; |
| } |
| |
| private TaskAttempt createMockAttempt(String attemptIDStr, TaskAttemptState status, float progress) { |
| TaskAttempt mockAttempt = mock(TaskAttempt.class); |
| |
| doReturn(TezTaskAttemptID.fromString(attemptIDStr)).when(mockAttempt).getTaskAttemptID(); |
| doReturn(status).when(mockAttempt).getState(); |
| doReturn(progress).when(mockAttempt).getProgress(); |
| |
| TezCounters counters = new TezCounters(); |
| counters.addGroup("g1", "g1"); |
| counters.addGroup("g2", "g2"); |
| counters.addGroup("g3", "g3"); |
| counters.addGroup("g4", "g4"); |
| counters.findCounter("g1", "g1_c1").setValue(101); |
| counters.findCounter("g1", "g1_c2").setValue(102); |
| counters.findCounter("g2", "g2_c3").setValue(103); |
| counters.findCounter("g2", "g2_c4").setValue(104); |
| counters.findCounter("g3", "g3_c5").setValue(105); |
| counters.findCounter("g3", "g3_c6").setValue(106); |
| |
| doReturn(counters).when(mockAttempt).getCounters(); |
| |
| return mockAttempt; |
| } |
| |
| private void verifySingleAttemptResult(TaskAttempt mockTask, Map<String, String> taskResult) { |
| Assert.assertEquals(3, taskResult.size()); |
| Assert.assertEquals(mockTask.getTaskAttemptID().toString(), taskResult.get("id")); |
| Assert.assertEquals(mockTask.getState().toString(), taskResult.get("status")); |
| Assert.assertEquals(Float.toString(mockTask.getProgress()), taskResult.get("progress")); |
| } |
| |
| } |