blob: b93b29878ef7d57fa7ed8fb640cc1e938392d949 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.dag.app.dag.impl;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.VertexManagerPlugin;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.dag.api.VertexManagerPluginDescriptor;
import org.apache.tez.dag.app.AppContext;
import org.apache.tez.dag.app.dag.StateChangeNotifier;
import org.apache.tez.dag.app.dag.Vertex;
import org.apache.tez.dag.app.dag.event.CallableEvent;
import org.apache.tez.dag.app.dag.event.VertexEventInputDataInformation;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.TaskAttemptIdentifier;
import org.apache.tez.runtime.api.events.InputDataInformationEvent;
import org.apache.tez.runtime.api.events.VertexManagerEvent;
import org.apache.tez.runtime.api.impl.TezEvent;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
@SuppressWarnings({ "rawtypes", "unchecked" })
public class TestVertexManager {
AppContext mockAppContext;
ListeningExecutorService execService;
Vertex mockVertex;
EventHandler mockHandler;
ArgumentCaptor<VertexEventInputDataInformation> requestCaptor;
@Before
public void setup() {
mockAppContext = mock(AppContext.class, RETURNS_DEEP_STUBS);
execService = mock(ListeningExecutorService.class);
final ListenableFuture<Void> mockFuture = mock(ListenableFuture.class);
Mockito.doAnswer(new Answer() {
public ListenableFuture<Void> answer(InvocationOnMock invocation) {
Object[] args = invocation.getArguments();
CallableEvent e = (CallableEvent) args[0];
new CallableEventDispatcher().handle(e);
return mockFuture;
}})
.when(execService).submit((Callable<Void>) any());
doReturn(execService).when(mockAppContext).getExecService();
mockVertex = mock(Vertex.class, RETURNS_DEEP_STUBS);
doReturn("vertex1").when(mockVertex).getName();
mockHandler = mock(EventHandler.class);
when(mockAppContext.getEventHandler()).thenReturn(mockHandler);
when(
mockAppContext.getCurrentDAG().getVertex(any(String.class))
.getTotalTasks()).thenReturn(1);
requestCaptor = ArgumentCaptor.forClass(VertexEventInputDataInformation.class);
}
public static class CheckUserPayloadVertexManagerPlugin extends VertexManagerPlugin {
public CheckUserPayloadVertexManagerPlugin(VertexManagerPluginContext context) {
super(context);
assertNotNull(context.getUserPayload());
}
@Override
public void initialize() throws Exception {}
@Override
public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) throws Exception {}
@Override
public void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor,
List<Event> events) throws Exception {}
}
@Test(timeout = 5000)
public void testVertexManagerPluginCtorAccessUserPayload() throws IOException, TezException {
byte[] randomUserPayload = {1,2,3};
UserPayload userPayload = UserPayload.create(ByteBuffer.wrap(randomUserPayload));
VertexManager vm =
new VertexManager(
VertexManagerPluginDescriptor.create(CheckUserPayloadVertexManagerPlugin.class
.getName()).setUserPayload(userPayload), UserGroupInformation.getCurrentUser(),
mockVertex, mockAppContext, mock(StateChangeNotifier.class));
}
@Test(timeout = 5000)
public void testOnRootVertexInitialized() throws Exception {
VertexManager vm =
new VertexManager(
VertexManagerPluginDescriptor.create(RootInputVertexManager.class
.getName()), UserGroupInformation.getCurrentUser(),
mockVertex, mockAppContext, mock(StateChangeNotifier.class));
vm.initialize();
InputDescriptor id1 = mock(InputDescriptor.class);
List<Event> events1 = new LinkedList<Event>();
InputDataInformationEvent diEvent1 =
InputDataInformationEvent.createWithSerializedPayload(0, null);
events1.add(diEvent1);
vm.onRootVertexInitialized("input1", id1, events1);
verify(mockHandler, times(1)).handle(requestCaptor.capture());
List<TezEvent> tezEvents1 = requestCaptor.getValue().getEvents();
assertEquals(1, tezEvents1.size());
assertEquals(diEvent1, tezEvents1.get(0).getEvent());
InputDescriptor id2 = mock(InputDescriptor.class);
List<Event> events2 = new LinkedList<Event>();
InputDataInformationEvent diEvent2 =
InputDataInformationEvent.createWithSerializedPayload(0, null);
events2.add(diEvent2);
vm.onRootVertexInitialized("input1", id2, events2);
verify(mockHandler, times(2)).handle(requestCaptor.capture());
List<TezEvent> tezEvents2 = requestCaptor.getValue().getEvents();
assertEquals(tezEvents2.size(), 1);
assertEquals(diEvent2, tezEvents2.get(0).getEvent());
}
/**
* TEZ-1647
* custom vertex manager generates events only when both i1 and i2 are initialized.
* @throws Exception
*/
@Test(timeout = 5000)
public void testOnRootVertexInitialized2() throws Exception {
VertexManager vm =
new VertexManager(
VertexManagerPluginDescriptor.create(CustomVertexManager.class
.getName()), UserGroupInformation.getCurrentUser(),
mockVertex, mockAppContext, mock(StateChangeNotifier.class));
vm.initialize();
InputDescriptor id1 = mock(InputDescriptor.class);
List<Event> events1 = new LinkedList<Event>();
InputDataInformationEvent diEvent1 =
InputDataInformationEvent.createWithSerializedPayload(0, null);
events1.add(diEvent1);
// do not call context.addRootInputEvents, just cache the TezEvent
vm.onRootVertexInitialized("input1", id1, events1);
verify(mockHandler, times(1)).handle(requestCaptor.capture());
List<TezEvent> tezEventsAfterInput1 = requestCaptor.getValue().getEvents();
assertEquals(0, tezEventsAfterInput1.size());
InputDescriptor id2 = mock(InputDescriptor.class);
List<Event> events2 = new LinkedList<Event>();
InputDataInformationEvent diEvent2 =
InputDataInformationEvent.createWithSerializedPayload(0, null);
events2.add(diEvent2);
// call context.addRootInputEvents(input1), context.addRootInputEvents(input2)
vm.onRootVertexInitialized("input2", id2, events2);
verify(mockHandler, times(2)).handle(requestCaptor.capture());
List<TezEvent> tezEventsAfterInput2 = requestCaptor.getValue().getEvents();
assertEquals(2, tezEventsAfterInput2.size());
// also verify the EventMetaData
Set<String> edgeVertexSet = new HashSet<String>();
for (TezEvent tezEvent : tezEventsAfterInput2) {
edgeVertexSet.add(tezEvent.getDestinationInfo().getEdgeVertexName());
}
assertEquals(Sets.newHashSet("input1","input2"), edgeVertexSet);
}
public static class CustomVertexManager extends VertexManagerPlugin {
private Map<String,List<Event>> cachedEventMap = new HashMap<String, List<Event>>();
public CustomVertexManager(VertexManagerPluginContext context) {
super(context);
}
@Override
public void initialize() {
}
@Override
public void onVertexStarted(List<TaskAttemptIdentifier> completions) {
}
@Override
public void onSourceTaskCompleted(TaskAttemptIdentifier attempt) {
}
@Override
public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {
}
/**
* only addRootInputEvents when it is "input2", otherwise just cache it.
*/
@Override
public void onRootVertexInitialized(String inputName,
InputDescriptor inputDescriptor, List<Event> events) {
cachedEventMap.put(inputName, events);
if (inputName.equals("input2")) {
for (Map.Entry<String, List<Event>> entry : cachedEventMap.entrySet()) {
List<InputDataInformationEvent> riEvents = Lists.newLinkedList();
for (Event event : events) {
riEvents.add((InputDataInformationEvent)event);
}
getContext().addRootInputEvents(entry.getKey(), riEvents);
}
}
}
}
}