blob: 7deed4846f876fe70432bb95db4be1a83a2ccc20 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.dag.api.client.rpc;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.io.IOException;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ApplicationReport;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.Vertex;
import org.apache.tez.dag.api.client.DAGStatus;
import org.apache.tez.dag.api.client.StatusGetOpts;
import org.apache.tez.dag.api.client.VertexStatus;
import org.apache.tez.dag.api.client.rpc.DAGClientAMProtocolRPC.GetDAGStatusRequestProto;
import org.apache.tez.dag.api.client.rpc.DAGClientAMProtocolRPC.GetDAGStatusResponseProto;
import org.apache.tez.dag.api.client.rpc.DAGClientAMProtocolRPC.GetVertexStatusRequestProto;
import org.apache.tez.dag.api.client.rpc.DAGClientAMProtocolRPC.GetVertexStatusResponseProto;
import org.apache.tez.dag.api.client.rpc.DAGClientAMProtocolRPC.TryKillDAGRequestProto;
import org.apache.tez.dag.api.records.DAGProtos.DAGStatusProto;
import org.apache.tez.dag.api.records.DAGProtos.DAGStatusStateProto;
import org.apache.tez.dag.api.records.DAGProtos.ProgressProto;
import org.apache.tez.dag.api.records.DAGProtos.StatusGetOptsProto;
import org.apache.tez.dag.api.records.DAGProtos.StringProgressPairProto;
import org.apache.tez.dag.api.records.DAGProtos.TezCounterGroupProto;
import org.apache.tez.dag.api.records.DAGProtos.TezCounterProto;
import org.apache.tez.dag.api.records.DAGProtos.TezCountersProto;
import org.apache.tez.dag.api.records.DAGProtos.VertexStatusProto;
import org.apache.tez.dag.api.records.DAGProtos.VertexStatusStateProto;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentMatcher;
import org.mockito.internal.util.collections.Sets;
import com.google.protobuf.RpcController;
import com.google.protobuf.ServiceException;
/**
 * Unit tests for {@link DAGClientRPCImpl}.
 *
 * <p>The client under test is wired to a Mockito mock of
 * {@link DAGClientAMProtocolBlockingPB}, so every test only checks which
 * protobuf requests the client sends and how it converts the canned protobuf
 * responses into {@link DAGStatus} / {@link VertexStatus} objects.
 */
public class TestDAGClient {

  private DAGClientRPCImpl dagClient;
  private ApplicationId mockAppId;
  private ApplicationReport mockAppReport;
  private String dagIdStr;
  private DAGClientAMProtocolBlockingPB mockProxy;

  // Canned status payloads, rebuilt before every test by setUpData().
  private VertexStatusProto vertexStatusProtoWithoutCounters;
  private VertexStatusProto vertexStatusProtoWithCounters;
  private DAGStatusProto dagStatusProtoWithoutCounters;
  private DAGStatusProto dagStatusProtoWithCounters;

  /**
   * Builds the four canned status protos (DAG / vertex, each with and without
   * counters) that the mocked proxy hands back.
   */
  private void setUpData() {
    // DAG
    ProgressProto dagProgressProto = ProgressProto.newBuilder()
        .setFailedTaskCount(1)
        .setKilledTaskCount(1)
        .setRunningTaskCount(2)
        .setSucceededTaskCount(2)
        .setTotalTaskCount(6)
        .build();

    TezCountersProto dagCountersProto = TezCountersProto.newBuilder()
        .addCounterGroups(TezCounterGroupProto.newBuilder()
            .setName("DAGGroup")
            .addCounters(TezCounterProto.newBuilder()
                .setDisplayName("dag_counter_1")
                .setValue(99)))
        .build();

    dagStatusProtoWithoutCounters = DAGStatusProto.newBuilder()
        .addDiagnostics("Diagnostics_0")
        .setState(DAGStatusStateProto.DAG_RUNNING)
        .setDAGProgress(dagProgressProto)
        .addVertexProgress(
            StringProgressPairProto.newBuilder().setKey("v1")
                .setProgress(ProgressProto.newBuilder()
                    .setFailedTaskCount(0)
                    .setSucceededTaskCount(0)
                    .setKilledTaskCount(0)))
        .addVertexProgress(
            StringProgressPairProto.newBuilder().setKey("v2")
                .setProgress(ProgressProto.newBuilder()
                    .setFailedTaskCount(1)
                    .setSucceededTaskCount(1)
                    .setKilledTaskCount(1)))
        .build();

    // Same payload as above, plus the counters section.
    dagStatusProtoWithCounters = DAGStatusProto.newBuilder(dagStatusProtoWithoutCounters)
        .setDagCounters(dagCountersProto)
        .build();

    // Vertex
    ProgressProto vertexProgressProto = ProgressProto.newBuilder()
        .setFailedTaskCount(1)
        .setKilledTaskCount(0)
        .setRunningTaskCount(0)
        .setSucceededTaskCount(1)
        .build();

    TezCountersProto vertexCountersProto = TezCountersProto.newBuilder()
        .addCounterGroups(TezCounterGroupProto.newBuilder()
            .addCounters(TezCounterProto.newBuilder()
                .setDisplayName("vertex_counter_1")
                .setValue(99)))
        .build();

    vertexStatusProtoWithoutCounters = VertexStatusProto.newBuilder()
        .addDiagnostics("V_Diagnostics_0")
        .setProgress(vertexProgressProto)
        // make sure the waitForCompletion is able to finish
        .setState(VertexStatusStateProto.VERTEX_SUCCEEDED)
        .build();

    // Same payload as above, plus the counters section.
    vertexStatusProtoWithCounters = VertexStatusProto.newBuilder(vertexStatusProtoWithoutCounters)
        .setVertexCounters(vertexCountersProto)
        .build();
  }

  /**
   * Matches any {@link GetDAGStatusRequestProto} that asks for counters,
   * regardless of where {@code GET_COUNTERS} sits in the option list.
   */
  private static class DAGCounterRequestMatcher extends ArgumentMatcher<GetDAGStatusRequestProto> {
    @Override
    public boolean matches(Object argument) {
      if (!(argument instanceof GetDAGStatusRequestProto)) {
        return false;
      }
      GetDAGStatusRequestProto requestProto = (GetDAGStatusRequestProto) argument;
      // contains() instead of get(0): the match must not depend on option order.
      return requestProto.getStatusOptionsList().contains(StatusGetOptsProto.GET_COUNTERS);
    }
  }

  /**
   * Matches any {@link GetVertexStatusRequestProto} that asks for counters,
   * regardless of where {@code GET_COUNTERS} sits in the option list.
   */
  private static class VertexCounterRequestMatcher extends ArgumentMatcher<GetVertexStatusRequestProto> {
    @Override
    public boolean matches(Object argument) {
      if (!(argument instanceof GetVertexStatusRequestProto)) {
        return false;
      }
      GetVertexStatusRequestProto requestProto = (GetVertexStatusRequestProto) argument;
      // contains() instead of get(0): the match must not depend on option order.
      return requestProto.getStatusOptionsList().contains(StatusGetOptsProto.GET_COUNTERS);
    }
  }

  /** Wraps a DAG status payload in the RPC response message. */
  private static GetDAGStatusResponseProto dagStatusResponse(DAGStatusProto status) {
    return GetDAGStatusResponseProto.newBuilder().setDagStatus(status).build();
  }

  /** Wraps a vertex status payload in the RPC response message. */
  private static GetVertexStatusResponseProto vertexStatusResponse(VertexStatusProto status) {
    return GetVertexStatusResponseProto.newBuilder().setVertexStatus(status).build();
  }

  /** Returns a copy of {@code status} whose state is {@code DAG_SUCCEEDED}. */
  private static DAGStatusProto asSucceeded(DAGStatusProto status) {
    return DAGStatusProto.newBuilder(status)
        .setState(DAGStatusStateProto.DAG_SUCCEEDED)
        .build();
  }

  /** The request expected for the test DAG when no status options are given. */
  private GetDAGStatusRequestProto dagStatusRequest() {
    return GetDAGStatusRequestProto.newBuilder().setDagId(dagIdStr).build();
  }

  /** The request expected for the test DAG when counters are requested. */
  private GetDAGStatusRequestProto dagStatusRequestWithCounters() {
    return GetDAGStatusRequestProto.newBuilder()
        .setDagId(dagIdStr)
        .addStatusOptions(StatusGetOptsProto.GET_COUNTERS)
        .build();
  }

  /**
   * Creates the mocks, stubs the proxy with the canned responses and injects
   * the mocks into a fresh {@link DAGClientRPCImpl}.
   */
  @Before
  public void setUp() throws YarnException, IOException, TezException, ServiceException {
    setUpData();

    /////////////// mock //////////////////////
    mockAppId = mock(ApplicationId.class);
    mockAppReport = mock(ApplicationReport.class);
    dagIdStr = "dag_9999_0001_1";
    mockProxy = mock(DAGClientAMProtocolBlockingPB.class);

    // Return the response with counters if the request matches the counter
    // matcher, and the response without counters otherwise. The generic any()
    // stub is declared first on purpose: Mockito prefers the most recently
    // declared stub that matches, so the counter-specific stub wins whenever
    // both apply.
    when(mockProxy.getDAGStatus(isNull(RpcController.class), any(GetDAGStatusRequestProto.class)))
        .thenReturn(dagStatusResponse(dagStatusProtoWithoutCounters));
    when(mockProxy.getDAGStatus(isNull(RpcController.class), argThat(new DAGCounterRequestMatcher())))
        .thenReturn(dagStatusResponse(dagStatusProtoWithCounters));

    when(mockProxy.getVertexStatus(isNull(RpcController.class), any(GetVertexStatusRequestProto.class)))
        .thenReturn(vertexStatusResponse(vertexStatusProtoWithoutCounters));
    when(mockProxy.getVertexStatus(isNull(RpcController.class), argThat(new VertexCounterRequestMatcher())))
        .thenReturn(vertexStatusResponse(vertexStatusProtoWithCounters));

    dagClient = new DAGClientRPCImpl(mockAppId, dagIdStr, new TezConfiguration());
    // Inject the mocks directly into the client's fields.
    dagClient.appReport = mockAppReport;
    dagClient.proxy = mockProxy;
  }

  /** The client must hand back the application id and report it was given. */
  @Test
  public void testApp() throws IOException, TezException, ServiceException {
    assertEquals(mockAppId, dagClient.getApplicationId());
    assertEquals(mockAppReport, dagClient.getApplicationReport());
  }

  /** DAG status is fetched with and without counters via the expected requests. */
  @Test
  public void testDAGStatus() throws Exception {
    DAGStatus resultDagStatus = dagClient.getDAGStatus(null);
    verify(mockProxy, times(1)).getDAGStatus(null, dagStatusRequest());
    assertEquals(new DAGStatus(dagStatusProtoWithoutCounters), resultDagStatus);
    System.out.println("DAGStatusWithoutCounter:" + resultDagStatus);

    resultDagStatus = dagClient.getDAGStatus(Sets.newSet(StatusGetOpts.GET_COUNTERS));
    verify(mockProxy, times(1)).getDAGStatus(null, dagStatusRequestWithCounters());
    assertEquals(new DAGStatus(dagStatusProtoWithCounters), resultDagStatus);
    System.out.println("DAGStatusWithCounter:" + resultDagStatus);
  }

  /** Vertex status is fetched with and without counters via the expected requests. */
  @Test
  public void testVertexStatus() throws Exception {
    VertexStatus resultVertexStatus = dagClient.getVertexStatus("v1", null);
    verify(mockProxy).getVertexStatus(null, GetVertexStatusRequestProto.newBuilder()
        .setDagId(dagIdStr).setVertexName("v1").build());
    assertEquals(new VertexStatus(vertexStatusProtoWithoutCounters), resultVertexStatus);
    System.out.println("VertexWithoutCounter:" + resultVertexStatus);

    resultVertexStatus = dagClient.getVertexStatus("v1", Sets.newSet(StatusGetOpts.GET_COUNTERS));
    verify(mockProxy).getVertexStatus(null, GetVertexStatusRequestProto.newBuilder()
        .setDagId(dagIdStr).setVertexName("v1")
        .addStatusOptions(StatusGetOptsProto.GET_COUNTERS).build());
    assertEquals(new VertexStatus(vertexStatusProtoWithCounters), resultVertexStatus);
    System.out.println("VertexWithCounter:" + resultVertexStatus);
  }

  /** tryKillDAG() must send exactly one kill request for the test DAG. */
  @Test
  public void testTryKillDAG() throws Exception {
    dagClient.tryKillDAG();
    verify(mockProxy, times(1)).tryKillDAG(null, TryKillDAGRequestProto.newBuilder()
        .setDagId(dagIdStr).build());
  }

  /** waitForCompletion() must poll until the DAG reports a completed state. */
  @Test
  public void testWaitForCompletion() throws Exception {
    // first time return DAG_RUNNING, second time return DAG_SUCCEEDED
    // (declared after the setUp stubs, so this stub takes precedence)
    when(mockProxy.getDAGStatus(isNull(RpcController.class), any(GetDAGStatusRequestProto.class)))
        .thenReturn(dagStatusResponse(dagStatusProtoWithoutCounters))
        .thenReturn(dagStatusResponse(asSucceeded(dagStatusProtoWithoutCounters)));

    dagClient.waitForCompletion();
    verify(mockProxy, times(2)).getDAGStatus(null, dagStatusRequest());
  }

  /**
   * waitForCompletionWithStatusUpdates() must poll with the request that
   * corresponds to the status options it was given.
   */
  @Test
  public void testWaitForCompletionWithStatusUpdates() throws Exception {
    // first time return DAG_RUNNING, second time return DAG_SUCCEEDED
    when(mockProxy.getDAGStatus(isNull(RpcController.class), any(GetDAGStatusRequestProto.class)))
        .thenReturn(dagStatusResponse(dagStatusProtoWithoutCounters))
        .thenReturn(dagStatusResponse(asSucceeded(dagStatusProtoWithoutCounters)));
    dagClient.waitForCompletionWithStatusUpdates(null, null);
    verify(mockProxy, times(2)).getDAGStatus(null, dagStatusRequest());

    // first time return DAG_RUNNING, second time return DAG_SUCCEEDED
    when(mockProxy.getDAGStatus(isNull(RpcController.class), any(GetDAGStatusRequestProto.class)))
        .thenReturn(dagStatusResponse(dagStatusProtoWithCounters))
        .thenReturn(dagStatusResponse(asSucceeded(dagStatusProtoWithCounters)));
    dagClient.waitForCompletionWithStatusUpdates(
        Sets.newSet(new Vertex("v1", null, 1, Resource.newInstance(1024, 1))),
        Sets.newSet(StatusGetOpts.GET_COUNTERS));
    verify(mockProxy, times(2)).getDAGStatus(null, dagStatusRequestWithCounters());
  }

  /**
   * waitForCompletionWithAllStatusUpdates() is expected to issue three DAG
   * status requests when the DAG succeeds on the third response.
   */
  @Test
  public void testWaitForCompletionWithAllStatusUpdates() throws Exception {
    // first time and second time return DAG_RUNNING, third time return DAG_SUCCEEDED
    when(mockProxy.getDAGStatus(isNull(RpcController.class), any(GetDAGStatusRequestProto.class)))
        .thenReturn(dagStatusResponse(dagStatusProtoWithoutCounters))
        .thenReturn(dagStatusResponse(dagStatusProtoWithoutCounters))
        .thenReturn(dagStatusResponse(asSucceeded(dagStatusProtoWithoutCounters)));

    // first time for getVertexSet
    // second & third time for check completion
    dagClient.waitForCompletionWithAllStatusUpdates(null);
    verify(mockProxy, times(3)).getDAGStatus(null, dagStatusRequest());

    when(mockProxy.getDAGStatus(isNull(RpcController.class), any(GetDAGStatusRequestProto.class)))
        .thenReturn(dagStatusResponse(dagStatusProtoWithCounters))
        .thenReturn(dagStatusResponse(dagStatusProtoWithCounters))
        .thenReturn(dagStatusResponse(asSucceeded(dagStatusProtoWithCounters)));
    dagClient.waitForCompletionWithAllStatusUpdates(Sets.newSet(StatusGetOpts.GET_COUNTERS));
    verify(mockProxy, times(3)).getDAGStatus(null, dagStatusRequestWithCounters());
  }
}