MAPREDUCE-2224. Fix synchronization bugs in JvmManager. Contributed by Todd Lipcon

git-svn-id: https://svn.apache.org/repos/asf/hadoop/mapreduce/trunk@1053263 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/CHANGES.txt b/CHANGES.txt
index 5c15a8a..abd38e0 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -451,6 +451,8 @@
     MAPREDUCE-1783. FairScheduler initializes tasks only when the job can be
     run. (Ramkumar Vadali via schen)
 
+    MAPREDUCE-2224. Fix synchronization bugs in JvmManager. (todd)
+
 Release 0.21.1 - Unreleased
 
   NEW FEATURES
diff --git a/src/java/org/apache/hadoop/mapred/JvmManager.java b/src/java/org/apache/hadoop/mapred/JvmManager.java
index 957cbf0..9cbbd5b 100644
--- a/src/java/org/apache/hadoop/mapred/JvmManager.java
+++ b/src/java/org/apache/hadoop/mapred/JvmManager.java
@@ -87,10 +87,10 @@
    */
   void setPidToJvm(JVMId jvmId, String pid) {
     if (jvmId.isMapJVM()) {
-      mapJvmManager.jvmIdToPid.put(jvmId, pid);
+      mapJvmManager.setPidForJvm(jvmId, pid);
     }
     else {
-      reduceJvmManager.jvmIdToPid.put(jvmId, pid);
+      reduceJvmManager.setPidForJvm(jvmId, pid);
     }
   }
   
@@ -100,15 +100,9 @@
   String getPid(TaskRunner t) {
     if (t != null && t.getTask() != null) {
       if (t.getTask().isMapTask()) {
-        JVMId id = mapJvmManager.runningTaskToJvm.get(t);
-        if (id != null) {
-          return mapJvmManager.jvmIdToPid.get(id);
-        }
+        return mapJvmManager.getPidByRunningTask(t);
       } else {
-        JVMId id = reduceJvmManager.runningTaskToJvm.get(t);
-        if (id != null) {
-          return reduceJvmManager.jvmIdToPid.get(id);
-        }
+        return reduceJvmManager.getPidByRunningTask(t);
       }
     }
     return null;
@@ -188,9 +182,6 @@
     //Mapping from the JVM IDs to Reduce JVM processes
     Map <JVMId, JvmRunner> jvmIdToRunner = 
       new HashMap<JVMId, JvmRunner>();
-    //Mapping from the JVM IDs to process IDs
-    Map <JVMId, String> jvmIdToPid = 
-      new HashMap<JVMId, String>();
     
     int maxJvms;
     boolean isMap;
@@ -210,7 +201,7 @@
         TaskRunner t) {
       jvmToRunningTask.put(jvmId, t);
       runningTaskToJvm.put(t,jvmId);
-      jvmIdToRunner.get(jvmId).setBusy(true);
+      jvmIdToRunner.get(jvmId).setTaskRunner(t);
     }
     
     synchronized public TaskInProgress getTaskForJvm(JVMId jvmId)
@@ -246,6 +237,20 @@
       }
       return null;
     }
+
+    synchronized String getPidByRunningTask(TaskRunner t) {
+      JVMId id = runningTaskToJvm.get(t);
+      if (id != null) {
+        return jvmIdToRunner.get(id).getPid();
+      }
+      return null;
+    }
+
+    synchronized void setPidForJvm(JVMId jvmId, String pid) {
+      JvmRunner runner = jvmIdToRunner.get(jvmId);
+      assert runner != null : "Task must have a runner to set a pid";
+      runner.setPid(pid);
+    }
     
     synchronized public boolean isJvmknown(JVMId jvmId) {
       return jvmIdToRunner.containsKey(jvmId);
@@ -282,14 +287,20 @@
       removeJvm(jvmRunner.jvmId);
     }
 
-    synchronized void dumpStack(TaskRunner tr) {
-      JVMId jvmId = runningTaskToJvm.get(tr);
-      if (null != jvmId) {
-        JvmRunner jvmRunner = jvmIdToRunner.get(jvmId);
-        if (null != jvmRunner) {
-          jvmRunner.dumpChildStacks();
+    void dumpStack(TaskRunner tr) {
+      JvmRunner jvmRunner = null;
+      synchronized (this) {
+        JVMId jvmId = runningTaskToJvm.get(tr);
+        if (null != jvmId) {
+          jvmRunner = jvmIdToRunner.get(jvmId);
         }
       }
+
+      // Don't want to hold JvmManager lock while dumping stacks for one
+      // task.
+      if (null != jvmRunner) {
+        jvmRunner.dumpChildStacks();
+      }
     }
 
     synchronized public void stop() {
@@ -307,7 +318,6 @@
     
     synchronized private void removeJvm(JVMId jvmId) {
       jvmIdToRunner.remove(jvmId);
-      jvmIdToPid.remove(jvmId);
     }
     private synchronized void reapJvm( 
         TaskRunner t, JvmEnv env) {
@@ -377,7 +387,7 @@
             " " + getDetails());
     }
     
-    private String getDetails() {
+    private synchronized String getDetails() {
       StringBuffer details = new StringBuffer();
       details.append("Number of active JVMs:").
               append(jvmIdToRunner.size());
@@ -390,14 +400,14 @@
           append(" #Tasks ran: "). 
           append(jvmIdToRunner.get(jvmId).numTasksRan).
           append(" Currently busy? ").
-          append(jvmIdToRunner.get(jvmId).busy).
+          append(jvmIdToRunner.get(jvmId).isBusy()).
           append(" Currently running: "). 
           append(jvmToRunningTask.get(jvmId).getTask().getTaskID().toString());
       }
       return details.toString();
     }
 
-    private void spawnNewJvm(JobID jobId, JvmEnv env,  
+    private synchronized void spawnNewJvm(JobID jobId, JvmEnv env,  
         TaskRunner t) {
       JvmRunner jvmRunner = new JvmRunner(env,jobId);
       jvmIdToRunner.put(jvmRunner.jvmId, jvmRunner);
@@ -433,7 +443,6 @@
       volatile int numTasksRan;
       final int numTasksToRun;
       JVMId jvmId;
-      volatile boolean busy = true;
       private ShellCommandExecutor shexec; // shell terminal for running the task
       //context used for starting JVM
       private TaskControllerContext initalContext;
@@ -442,6 +451,11 @@
         this.env = env;
         this.jvmId = new JVMId(jobId, isMap, rand.nextInt());
         this.numTasksToRun = env.conf.getNumTasksToExecutePerJvm();
+
+        this.initalContext = new TaskControllerContext();
+        initalContext.sleeptimeBeforeSigkill = tracker.getJobConf()
+          .getLong(TTConfig.TT_SLEEP_TIME_BEFORE_SIG_KILL,
+                   ProcessTree.DEFAULT_SLEEPTIME_BEFORE_SIGKILL);
         LOG.info("In JvmRunner constructed JVM ID: " + jvmId);
       }
       public void run() {
@@ -449,11 +463,9 @@
       }
 
       public void runChild(JvmEnv env) {
-        initalContext = new TaskControllerContext();
         try {
           env.vargs.add(Integer.toString(jvmId.getId()));
           //Launch the task controller to run task JVM
-          initalContext.task = jvmToRunningTask.get(jvmId).getTask();
           initalContext.env = env;
           tracker.getTaskController().launchTaskJVM(initalContext);
         } catch (IOException ioe) {
@@ -483,6 +495,19 @@
         }
       }
 
+      synchronized void setPid(String pid) {
+        assert initalContext != null;
+        initalContext.pid = pid;
+      }
+
+      synchronized String getPid() {
+        if (initalContext != null) {
+          return initalContext.pid;
+        } else {
+          return null;
+        }
+      }
+
       /** 
        * Kills the process. Also kills its subprocesses if the process(root of subtree
        * of processes) is created using setsid.
@@ -493,11 +518,6 @@
           // Check inital context before issuing a kill to prevent situations
           // where kill is issued before task is launched.
           if (initalContext != null && initalContext.env != null) {
-            initalContext.pid = jvmIdToPid.get(jvmId);
-            initalContext.sleeptimeBeforeSigkill = tracker.getJobConf()
-                .getLong(TTConfig.TT_SLEEP_TIME_BEFORE_SIG_KILL,
-                    ProcessTree.DEFAULT_SLEEPTIME_BEFORE_SIGKILL);
-
             // Destroy the task jvm
             controller.destroyTaskJVM(initalContext);
           } else {
@@ -518,11 +538,6 @@
           // Check inital context before issuing a signal to prevent situations
           // where signal is issued before task is launched.
           if (initalContext != null && initalContext.env != null) {
-            initalContext.pid = jvmIdToPid.get(jvmId);
-            initalContext.sleeptimeBeforeSigkill = tracker.getJobConf()
-                .getLong(TTConfig.TT_SLEEP_TIME_BEFORE_SIG_KILL,
-                    ProcessTree.DEFAULT_SLEEPTIME_BEFORE_SIGKILL);
-
             // signal the task jvm
             controller.dumpTaskStack(initalContext);
 
@@ -539,19 +554,20 @@
         }
       }
 
-      public void taskRan() {
-        busy = false;
+      public synchronized void taskRan() {
+        initalContext.task = null;
         numTasksRan++;
       }
       
       public boolean ranAll() {
         return(numTasksRan == numTasksToRun);
       }
-      public void setBusy(boolean busy) {
-        this.busy = busy;
+      public synchronized void setTaskRunner(TaskRunner runner) {
+        initalContext.task = runner.getTask();
+        assert initalContext.task != null;
       }
-      public boolean isBusy() {
-        return busy;
+      public synchronized boolean isBusy() {
+        return initalContext.task != null;
       }
     }
   }  
diff --git a/src/test/mapred/org/apache/hadoop/mapred/TestJvmManager.java b/src/test/mapred/org/apache/hadoop/mapred/TestJvmManager.java
index fe1ef37..8c47cf3 100644
--- a/src/test/mapred/org/apache/hadoop/mapred/TestJvmManager.java
+++ b/src/test/mapred/org/apache/hadoop/mapred/TestJvmManager.java
@@ -22,8 +22,15 @@
 import java.io.FileOutputStream;
 import java.io.FileReader;
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.Vector;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.fs.FileUtil;
 import org.apache.hadoop.mapred.JvmManager.JvmManagerForType;
 import org.apache.hadoop.mapred.JvmManager.JvmManagerForType.JvmRunner;
@@ -36,10 +43,12 @@
 import org.junit.Test;
 
 public class TestJvmManager {
+  static final Log LOG = LogFactory.getLog(TestJvmManager.class);
+
   private static File TEST_DIR = new File(System.getProperty("test.build.data",
       "/tmp"), TestJvmManager.class.getSimpleName());
-  private static int MAP_SLOTS = 1;
-  private static int REDUCE_SLOTS = 1;
+  private static int MAP_SLOTS = 10;
+  private static int REDUCE_SLOTS = 10;
   private TaskTracker tt;
   private JvmManager jvmManager;
   private JobConf ttConf;
@@ -98,7 +107,7 @@
     // launch a jvm
     JobConf taskConf = new JobConf(ttConf);
     TaskAttemptID attemptID = new TaskAttemptID("test", 0, TaskType.MAP, 0, 0);
-    Task task = new MapTask(null, attemptID, 0, null, MAP_SLOTS);
+    Task task = new MapTask(null, attemptID, 0, null, 1);
     task.setConf(taskConf);
     TaskInProgress tip = tt.new TaskInProgress(task, taskConf);
     File pidFile = new File(TEST_DIR, "pid");
@@ -162,7 +171,7 @@
 
     // launch another jvm and see it finishes properly
     attemptID = new TaskAttemptID("test", 0, TaskType.MAP, 0, 1);
-    task = new MapTask(null, attemptID, 0, null, MAP_SLOTS);
+    task = new MapTask(null, attemptID, 0, null, 1);
     task.setConf(taskConf);
     tip = tt.new TaskInProgress(task, taskConf);
     TaskRunner taskRunner2 = task.createRunner(tt, tip);
@@ -180,4 +189,139 @@
     jvmRunner.join();
     launcher.join();
   }
+
+
+  /**
+   * Create a bunch of tasks and use a special hash map to detect
+   * racy access to the various internal data structures of JvmManager.
+   * (Regression test for MAPREDUCE-2224)
+   */
+  @Test
+  public void testForRaces() throws Exception {
+    JvmManagerForType mapJvmManager = jvmManager
+        .getJvmManagerForType(TaskType.MAP);
+
+    // Sub out the HashMaps for maps that will detect racy access.
+    mapJvmManager.jvmToRunningTask = new RaceHashMap<JVMId, TaskRunner>();
+    mapJvmManager.runningTaskToJvm = new RaceHashMap<TaskRunner, JVMId>();
+    mapJvmManager.jvmIdToRunner = new RaceHashMap<JVMId, JvmRunner>();
+
+    // Launch a bunch of JVMs, but only allow MAP_SLOTS to run at once.
+    final ExecutorService exec = Executors.newFixedThreadPool(MAP_SLOTS);
+    final AtomicReference<Throwable> failed =
+      new AtomicReference<Throwable>();
+
+    for (int i = 0; i < MAP_SLOTS*5; i++) {
+      JobConf taskConf = new JobConf(ttConf);
+      TaskAttemptID attemptID = new TaskAttemptID("test", 0, TaskType.MAP, i, 0);
+      Task task = new MapTask(null, attemptID, i, null, 1);
+      task.setConf(taskConf);
+      TaskInProgress tip = tt.new TaskInProgress(task, taskConf);
+      File pidFile = new File(TEST_DIR, "pid_" + i);
+      final TaskRunner taskRunner = task.createRunner(tt, tip);
+      // launch a jvm which sleeps for 60 seconds
+      final Vector<String> vargs = new Vector<String>(2);
+      vargs.add(writeScript("script_" + i, "echo hi\n", pidFile).getAbsolutePath());
+      final File workDir = new File(TEST_DIR, "work_" + i);
+      workDir.mkdir();
+      final File stdout = new File(TEST_DIR, "stdout_" + i);
+      final File stderr = new File(TEST_DIR, "stderr_" + i);
+  
+      // launch the process and wait in a thread, till it finishes
+      Runnable launcher = new Runnable() {
+        public void run() {
+          try {
+            taskRunner.launchJvmAndWait(null, vargs, stdout, stderr, 100,
+                workDir, null);
+          } catch (Throwable t) {
+            failed.compareAndSet(null, t);
+            exec.shutdownNow();
+            return;
+          }
+        }
+      };
+      exec.submit(launcher);
+    }
+
+    exec.shutdown();
+    exec.awaitTermination(3, TimeUnit.MINUTES);
+    if (failed.get() != null) {
+      throw new RuntimeException(failed.get());
+    }
+  }
+
+  /**
+   * HashMap which detects racy usage by sleeping during operations
+   * and checking that no other threads access the map while asleep.
+   */
+  static class RaceHashMap<K,V> extends HashMap<K,V> {
+    Object syncData = new Object();
+    RuntimeException userStack = null;
+    boolean raced = false;
+    
+    private void checkInUse() {
+      synchronized (syncData) {
+        RuntimeException thisStack = new RuntimeException(Thread.currentThread().toString());
+
+        if (userStack != null && raced == false) {
+          RuntimeException other = userStack;
+          raced = true;
+          LOG.fatal("Race between two threads.");
+          LOG.fatal("First", thisStack);
+          LOG.fatal("Second", other);
+          throw new RuntimeException("Raced");
+        } else {
+          userStack = thisStack;
+        }
+      }
+    }
+
+    private void sleepABit() {
+      try {
+        Thread.sleep(60);
+      } catch (InterruptedException ie) {
+        Thread.currentThread().interrupt();
+      }
+    }
+
+    private void done() {
+      synchronized (syncData) {
+        userStack = null;
+      }
+    }
+
+    @Override
+    public V get(Object key) {
+      checkInUse();
+      try {
+        sleepABit();
+        return super.get(key);
+      } finally {
+        done();
+      }
+    }
+
+    @Override
+    public boolean containsKey(Object key) {
+      checkInUse();
+      try {
+        sleepABit();
+        return super.containsKey(key);
+      } finally {
+        done();
+      }
+    }
+    
+    @Override
+    public V put(K key, V val) {
+      checkInUse();
+      try {
+        sleepABit();
+        return super.put(key, val);
+      } finally {
+        done();
+      }
+    }
+  }
+
 }