| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.tez.dag.app.dag; |
| |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.fail; |
| import static org.mockito.Matchers.any; |
| import static org.mockito.Mockito.doReturn; |
| import static org.mockito.Mockito.mock; |
| import static org.mockito.Mockito.never; |
| import static org.mockito.Mockito.times; |
| import static org.mockito.Mockito.verify; |
| |
| import java.util.EnumSet; |
| import java.util.Iterator; |
| import java.util.List; |
| import java.util.concurrent.atomic.AtomicInteger; |
| |
| import com.google.common.collect.Lists; |
| |
| import org.apache.tez.dag.api.TezUncheckedException; |
| import org.apache.tez.dag.api.event.VertexState; |
| import org.apache.tez.dag.api.event.VertexStateUpdate; |
| import org.apache.tez.dag.api.event.VertexStateUpdateParallelismUpdated; |
| import org.apache.tez.dag.records.TezDAGID; |
| import org.apache.tez.dag.records.TezVertexID; |
| import org.junit.Assert; |
| import org.junit.Test; |
| import org.mockito.ArgumentCaptor; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| public class TestStateChangeNotifier { |
| |
| // uses the thread based notification code path but effectively blocks update |
| // events till listeners have been notified |
| public static class StateChangeNotifierForTest extends StateChangeNotifier { |
| private static final Logger LOG = LoggerFactory.getLogger(StateChangeNotifierForTest.class); |
| AtomicInteger count = new AtomicInteger(0); |
| AtomicInteger totalCount = new AtomicInteger(0); |
| |
| public StateChangeNotifierForTest(DAG dag) { |
| super(dag); |
| } |
| |
| public void reset() { |
| count.set(0); |
| totalCount.set(0); |
| } |
| |
| @Override |
| protected void processedEventFromQueue() { |
| // addedEventToQueue runs in dispatcher thread while |
| // processedEventFromQueue runs in state change notifier event handling thread. |
| // It is not guaranteed that addedEventToQueue is invoked before processedEventFromQueue. |
| // so sleep here until there's available events |
| while(count.get() <=0) { |
| try { |
| Thread.sleep(10); |
| LOG.info("sleep to wait for available events"); |
| } catch (InterruptedException e) { |
| e.printStackTrace(); |
| } |
| } |
| synchronized (count) { |
| if (count.decrementAndGet() == 0) { |
| count.notifyAll(); |
| } |
| } |
| } |
| |
| @Override |
| protected void addedEventToQueue() { |
| totalCount.incrementAndGet(); |
| synchronized (count) { |
| // processing may finish by the time we get here |
| if (count.incrementAndGet() > 0) { |
| try { |
| count.wait(); |
| } catch (InterruptedException e) { |
| e.printStackTrace(); |
| } |
| } |
| } |
| } |
| } |
| |
| @Test(timeout = 5000) |
| public void testEventsOnRegistration() { |
| TezDAGID dagId = TezDAGID.getInstance("1", 1, 1); |
| Vertex v1 = createMockVertex(dagId, 1); |
| Vertex v2 = createMockVertex(dagId, 2); |
| Vertex v3 = createMockVertex(dagId, 3); |
| DAG dag = createMockDag(dagId, v1, v2, v3); |
| |
| StateChangeNotifierForTest tracker = new StateChangeNotifierForTest(dag); |
| |
| // Vertex has sent one event |
| notifyTracker(tracker, v1, VertexState.RUNNING); |
| VertexStateUpdateListener mockListener11 = mock(VertexStateUpdateListener.class); |
| VertexStateUpdateListener mockListener12 = mock(VertexStateUpdateListener.class); |
| VertexStateUpdateListener mockListener13 = mock(VertexStateUpdateListener.class); |
| VertexStateUpdateListener mockListener14 = mock(VertexStateUpdateListener.class); |
| // Register for all states |
| tracker.registerForVertexUpdates(v1.getName(), null, mockListener11); |
| // Register for all states |
| tracker.registerForVertexUpdates(v1.getName(), EnumSet.allOf( |
| VertexState.class), mockListener12); |
| // Register for specific state, event generated |
| tracker.registerForVertexUpdates(v1.getName(), EnumSet.of( |
| VertexState.RUNNING), mockListener13); |
| // Register for specific state, event not generated |
| tracker.registerForVertexUpdates(v1.getName(), EnumSet.of( |
| VertexState.SUCCEEDED), mockListener14); |
| ArgumentCaptor<VertexStateUpdate> argumentCaptor = |
| ArgumentCaptor.forClass(VertexStateUpdate.class); |
| verify(mockListener11, times(1)).onStateUpdated(argumentCaptor.capture()); |
| assertEquals(VertexState.RUNNING, |
| argumentCaptor.getValue().getVertexState()); |
| verify(mockListener12, times(1)).onStateUpdated(argumentCaptor.capture()); |
| assertEquals(VertexState.RUNNING, |
| argumentCaptor.getValue().getVertexState()); |
| verify(mockListener13, times(1)).onStateUpdated(argumentCaptor.capture()); |
| assertEquals(VertexState.RUNNING, |
| argumentCaptor.getValue().getVertexState()); |
| verify(mockListener14, never()).onStateUpdated(any(VertexStateUpdate.class)); |
| |
| // Vertex has not notified of state |
| tracker.reset(); |
| VertexStateUpdateListener mockListener2 = mock(VertexStateUpdateListener.class); |
| tracker.registerForVertexUpdates(v2.getName(), null, mockListener2); |
| Assert.assertEquals(0, tracker.totalCount.get()); // there should no be any event sent out |
| verify(mockListener2, never()).onStateUpdated(any(VertexStateUpdate.class)); |
| |
| // Vertex has notified about parallelism update only |
| tracker.stateChanged(v3.getVertexId(), new VertexStateUpdateParallelismUpdated(v3.getName(), 23, -1)); |
| VertexStateUpdateListener mockListener3 = mock(VertexStateUpdateListener.class); |
| tracker.registerForVertexUpdates(v3.getName(), null, mockListener3); |
| verify(mockListener3, times(1)).onStateUpdated(argumentCaptor.capture()); |
| assertEquals(VertexState.PARALLELISM_UPDATED, |
| argumentCaptor.getValue().getVertexState()); |
| } |
| |
| @Test(timeout = 5000) |
| public void testSimpleStateUpdates() { |
| TezDAGID dagId = TezDAGID.getInstance("1", 1, 1); |
| Vertex v1 = createMockVertex(dagId, 1); |
| DAG dag = createMockDag(dagId, v1); |
| |
| StateChangeNotifierForTest tracker = new StateChangeNotifierForTest(dag); |
| |
| VertexStateUpdateListener mockListener = mock(VertexStateUpdateListener.class); |
| tracker.registerForVertexUpdates(v1.getName(), null, mockListener); |
| |
| List<VertexState> expectedStates = Lists.newArrayList( |
| VertexState.RUNNING, |
| VertexState.SUCCEEDED, |
| VertexState.FAILED, |
| VertexState.KILLED, |
| VertexState.RUNNING, |
| VertexState.SUCCEEDED); |
| |
| for (VertexState state : expectedStates) { |
| notifyTracker(tracker, v1, state); |
| } |
| |
| ArgumentCaptor<VertexStateUpdate> argumentCaptor = |
| ArgumentCaptor.forClass(VertexStateUpdate.class); |
| verify(mockListener, times(expectedStates.size())).onStateUpdated(argumentCaptor.capture()); |
| List<VertexStateUpdate> stateUpdatesSent = argumentCaptor.getAllValues(); |
| |
| Iterator<VertexState> expectedStateIter = |
| expectedStates.iterator(); |
| for (int i = 0; i < expectedStates.size(); i++) { |
| assertEquals(expectedStateIter.next(), stateUpdatesSent.get(i).getVertexState()); |
| } |
| } |
| |
| @Test(timeout = 5000) |
| public void testDuplicateRegistration() { |
| TezDAGID dagId = TezDAGID.getInstance("1", 1, 1); |
| Vertex v1 = createMockVertex(dagId, 1); |
| DAG dag = createMockDag(dagId, v1); |
| |
| StateChangeNotifierForTest tracker = new StateChangeNotifierForTest(dag); |
| VertexStateUpdateListener mockListener = mock(VertexStateUpdateListener.class); |
| |
| tracker.registerForVertexUpdates(v1.getName(), null, mockListener); |
| try { |
| tracker.registerForVertexUpdates(v1.getName(), null, mockListener); |
| fail("Expecting an error from duplicate registrations of the same listener"); |
| } catch (TezUncheckedException e) { |
| // Expected, ignore |
| } |
| } |
| |
| @Test(timeout = 5000) |
| public void testSpecificStateUpdates() { |
| TezDAGID dagId = TezDAGID.getInstance("1", 1, 1); |
| Vertex v1 = createMockVertex(dagId, 1); |
| DAG dag = createMockDag(dagId, v1); |
| |
| StateChangeNotifierForTest tracker = new StateChangeNotifierForTest(dag); |
| |
| VertexStateUpdateListener mockListener = mock(VertexStateUpdateListener.class); |
| tracker.registerForVertexUpdates(v1.getName(), EnumSet.of( |
| VertexState.RUNNING, |
| VertexState.SUCCEEDED), mockListener); |
| |
| List<VertexState> states = Lists.newArrayList( |
| VertexState.RUNNING, |
| VertexState.SUCCEEDED, |
| VertexState.FAILED, |
| VertexState.KILLED, |
| VertexState.RUNNING, |
| VertexState.SUCCEEDED); |
| List<VertexState> expectedStates = Lists.newArrayList( |
| VertexState.RUNNING, |
| VertexState.SUCCEEDED, |
| VertexState.RUNNING, |
| VertexState.SUCCEEDED); |
| |
| for (VertexState state : states) { |
| notifyTracker(tracker, v1, state); |
| } |
| |
| ArgumentCaptor<VertexStateUpdate> argumentCaptor = |
| ArgumentCaptor.forClass(VertexStateUpdate.class); |
| verify(mockListener, times(expectedStates.size())).onStateUpdated(argumentCaptor.capture()); |
| List<VertexStateUpdate> stateUpdatesSent = argumentCaptor.getAllValues(); |
| |
| Iterator<VertexState> expectedStateIter = |
| expectedStates.iterator(); |
| for (int i = 0; i < expectedStates.size(); i++) { |
| assertEquals(expectedStateIter.next(), stateUpdatesSent.get(i).getVertexState()); |
| } |
| } |
| |
| @Test(timeout = 5000) |
| public void testUnregister() { |
| TezDAGID dagId = TezDAGID.getInstance("1", 1, 1); |
| Vertex v1 = createMockVertex(dagId, 1); |
| DAG dag = createMockDag(dagId, v1); |
| |
| StateChangeNotifierForTest tracker = new StateChangeNotifierForTest(dag); |
| |
| VertexStateUpdateListener mockListener = mock(VertexStateUpdateListener.class); |
| tracker.registerForVertexUpdates(v1.getName(), null, mockListener); |
| |
| List<VertexState> expectedStates = Lists.newArrayList( |
| VertexState.RUNNING, |
| VertexState.SUCCEEDED, |
| VertexState.FAILED, |
| VertexState.KILLED, |
| VertexState.RUNNING, |
| VertexState.SUCCEEDED); |
| |
| int count = 0; |
| int numExpectedEvents = 3; |
| for (VertexState state : expectedStates) { |
| if (count == numExpectedEvents) { |
| tracker.unregisterForVertexUpdates(v1.getName(), mockListener); |
| } |
| notifyTracker(tracker, v1, state); |
| count++; |
| } |
| |
| ArgumentCaptor<VertexStateUpdate> argumentCaptor = |
| ArgumentCaptor.forClass(VertexStateUpdate.class); |
| verify(mockListener, times(numExpectedEvents)).onStateUpdated(argumentCaptor.capture()); |
| List<VertexStateUpdate> stateUpdatesSent = argumentCaptor.getAllValues(); |
| |
| Iterator<VertexState> expectedStateIter = |
| expectedStates.iterator(); |
| for (int i = 0; i < numExpectedEvents; i++) { |
| assertEquals(expectedStateIter.next(), stateUpdatesSent.get(i).getVertexState()); |
| } |
| } |
| |
| private DAG createMockDag(TezDAGID dagId, Vertex... vertices) { |
| DAG dag = mock(DAG.class); |
| doReturn(dagId).when(dag).getID(); |
| for (Vertex v : vertices) { |
| String vertexName = v.getName(); |
| TezVertexID vertexId = v.getVertexId(); |
| |
| doReturn(v).when(dag).getVertex(vertexName); |
| doReturn(v).when(dag).getVertex(vertexId); |
| } |
| return dag; |
| } |
| |
| private Vertex createMockVertex(TezDAGID dagId, int id) { |
| TezVertexID vertexId = TezVertexID.getInstance(dagId, id); |
| String vertexName = "vertex" + id; |
| Vertex v = mock(Vertex.class); |
| doReturn(vertexId).when(v).getVertexId(); |
| doReturn(vertexName).when(v).getName(); |
| return v; |
| } |
| |
| private void notifyTracker(StateChangeNotifier notifier, Vertex v, |
| VertexState state) { |
| notifier.stateChanged(v.getVertexId(), new VertexStateUpdate(v.getName(), state)); |
| } |
| } |