Merge pull request #62 from xujyan/jyx/lost_task

Fix a bug that causes task status updates not to be sent correctly when the task is killed.
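
The core of the fix is a termination handshake in the executor: _kill() now waits on an Event that _run_task() sets only after the terminal status update (TASK_KILLED or TASK_FAILED) has been sent, and driver.stop() is deferred by STOP_WAIT, so the update is no longer lost when the task is killed. Below is a minimal, self-contained sketch of that handshake (Python 2 style to match the code; SketchExecutor and its timings are illustrative stand-ins, not the real MysosExecutor):

import sys
from threading import Event, Thread
import time


class SketchExecutor(object):
  """Illustrative stand-in for MysosExecutor; not the real class."""

  def __init__(self):
    self._killed = False
    self._terminated = Event()  # Set only after the terminal update is sent.

  def _run_task(self):
    try:
      time.sleep(0.5)  # Stand-in for runner.start() / runner.join().
    finally:
      state = "TASK_KILLED" if self._killed else "TASK_FAILED"
      print("Sent terminal update: %s" % state)  # Stand-in for _send_update().
      self._terminated.set()

  def _kill(self):
    self._killed = True  # Flag the kill before the runner is stopped.
    self._terminated.wait(sys.maxint)  # Block until the update has gone out.
    print("Stopping driver")  # The real executor defers driver.stop() by 5s.


if __name__ == '__main__':
  executor = SketchExecutor()
  Thread(target=executor._run_task).start()
  executor._kill()

Running the sketch prints the TASK_KILLED update before the driver is stopped, which is the ordering this change enforces.
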
diff --git a/mysos/executor/executor.py b/mysos/executor/executor.py
index 122bc36..c3d17ff 100644
--- a/mysos/executor/executor.py
+++ b/mysos/executor/executor.py
@@ -1,4 +1,6 @@
 import json
+import sys
+from threading import Event
 import traceback
 
 from mysos.common.decorators import logged
@@ -9,6 +11,7 @@
 import mesos.interface.mesos_pb2 as mesos_pb2
 from twitter.common import log
 from twitter.common.concurrent import defer
+from twitter.common.quantity import Amount, Time
 
 
 class MysosExecutor(Executor):
@@ -16,6 +19,8 @@
     MysosExecutor is a fine-grained executor, i.e., one executor executes a single task.
   """
 
+  STOP_WAIT = Amount(5, Time.SECONDS)
+
   def __init__(self, runner_provider, sandbox):
     """
       :param runner_provider: An implementation of TaskRunnerProvider.
@@ -27,6 +32,8 @@
     self._killed = False  # True if the executor's singleton task is killed by the scheduler.
     self._sandbox = sandbox
 
+    self._terminated = Event()  # Set when the runner has terminated.
+
   # --- Mesos methods. ---
   @logged
   def registered(self, driver, executorInfo, frameworkInfo, slaveInfo):
@@ -81,21 +88,24 @@
       log.error(traceback.format_exc())
       # Send TASK_LOST for unknown errors.
       self._send_update(task.task_id.value, mesos_pb2.TASK_LOST)
-
-    # Wait for the task's return code (when it terminates).
-    try:
-      returncode = self._runner.join()
-      # Regardless of the return code, if '_runner' terminates, it failed!
-      log.error("Task process terminated with return code %s" % returncode)
-    except TaskError as e:
-      log.error("Task terminated: %s" % e)
-
-    if self._killed:
-      self._send_update(task.task_id.value, mesos_pb2.TASK_KILLED)
     else:
-      self._send_update(task.task_id.value, mesos_pb2.TASK_FAILED)
-
-    self._kill()
+      # Wait for the task's return code (when it terminates).
+      try:
+        returncode = self._runner.join()
+        # If '_runner' terminates, it has either failed or been killed.
+        log.warn("Task process terminated with return code %s" % returncode)
+      except TaskError as e:
+        log.error("Task terminated: %s" % e)
+      finally:
+        if self._killed:
+          self._send_update(task.task_id.value, mesos_pb2.TASK_KILLED)
+        else:
+          self._send_update(task.task_id.value, mesos_pb2.TASK_FAILED)
+        self._terminated.set()
+    finally:
+      # No matter what happens above, when we reach here the executor has no task to run so it
+      # should just commit seppuku.
+      self._kill()
 
   @logged
   def frameworkMessage(self, driver, message):
@@ -118,37 +128,37 @@
           'epoch': master_epoch,  # Send the epoch back without parsing it.
           'position': position
       }))
-    except TaskError as e:
-      # Log the error and do not reply to the framework.
+    except Exception as e:
       log.error("Committing suicide due to failure to process framework message: %s" % e)
+      log.error(traceback.format_exc())
       self._kill()
 
   @logged
   def killTask(self, driver, taskId):
     # Killing the task also kills the executor because there is one task per executor.
     log.info("Asked to kill task %s" % taskId.value)
-    self._killed = True
+
     self._kill()
 
   def _kill(self):
     if self._runner:
+      self._killed = True
       self._runner.stop()  # It could be already stopped. If so, self._runner.stop() is a no-op.
+      self._terminated.wait(sys.maxint)
 
     assert self._driver
 
     # TODO(jyx): Fix https://issues.apache.org/jira/browse/MESOS-243.
-    self._driver.stop()
+    defer(lambda: self._driver.stop(), delay=self.STOP_WAIT)
 
   @logged
   def shutdown(self, driver):
     log.info("Asked to shut down")
-    self._killed = True
     self._kill()
 
   @logged
   def error(self, driver, message):
     log.error("Shutting down due to error: %s" % message)
-    self._killed = True
     self._kill()
 
   def _send_update(self, task_id, state, message=None):
diff --git a/mysos/executor/mysos_task_runner.py b/mysos/executor/mysos_task_runner.py
index 0051bcb..eb75076 100644
--- a/mysos/executor/mysos_task_runner.py
+++ b/mysos/executor/mysos_task_runner.py
@@ -148,6 +148,15 @@
     self._exited.set()
 
   def stop(self, timeout=10):
+    with self._lock:
+      # stop() could be called by multiple threads. Locking so we only stop the runner once.
+      if self._stopping:
+        log.warn("The runner is already stopping/stopped")
+        return False
+      else:
+        log.info("Stopping runner")
+        self._stopping = True
+
     try:
       return self._stop(timeout)
     finally:
@@ -163,17 +172,10 @@
       :return: True if an active runner is stopped, False if the runner is not started or already
                stopping/stopped.
     """
-    if not self._started:
-      log.warn("Cannot stop the runner because it's not started")
-      return False
-
-    if self._stopping:
-      log.warn("The runner is already stopping/stopped")
-      return False
-
     with self._lock:
-      log.info("Stopping runner")
-      self._stopping = True
+      if not self._started:
+        log.warn("Cannot stop the runner because it's not started")
+        return False
 
       if not self._popen:
         log.info("The runner task did not start successfully so no need to kill it")
@@ -198,6 +200,8 @@
         except OSError as e:
           log.info("The sub-processes are already terminated: %s" % e)
           return False
+    else:
+      return True
 
     log.info("Waiting for process to terminate due to SIGKILL")
     if not self._exited.wait(timeout=timeout):
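The runner change above makes stop() idempotent: the _stopping check-and-set now happens under the lock, so the multiple kill paths (killTask, shutdown, error, and the executor's own _kill) stop the runner only once, while the actual termination work still runs outside the lock. A minimal sketch of that guard, with SketchRunner as an illustrative stand-in rather than the real MysosTaskRunner:

import threading


class SketchRunner(object):
  """Illustrative stand-in for the guarded stop() above."""

  def __init__(self):
    self._lock = threading.Lock()
    self._stopping = False

  def stop(self):
    with self._lock:
      if self._stopping:
        return False  # Subsequent calls are no-ops.
      self._stopping = True
    # The lock is released before the (potentially slow) termination work,
    # just as stop() above delegates to _stop() outside the 'with' block.
    print("Stopping runner")
    return True


if __name__ == '__main__':
  runner = SketchRunner()
  assert runner.stop()      # First call stops the runner.
  assert not runner.stop()  # Second call is rejected.
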
diff --git a/mysos/executor/mysql_task_control.py b/mysos/executor/mysql_task_control.py
index d5ca96f..40504c9 100644
--- a/mysos/executor/mysql_task_control.py
+++ b/mysos/executor/mysql_task_control.py
@@ -150,7 +150,7 @@
             conf_file=self._conf_file,
             buffer_pool_size=self._buffer_pool_size))
     log.info("Executing command: %s" % command)
-    self._process = subprocess.Popen(command, shell=True, env=env)
+    self._process = subprocess.Popen(command, shell=True, env=env, preexec_fn=os.setpgrp)
 
     # There is a delay before mysqld becomes available to accept requests. Wait for it.
     command = "%(cmd)s %(pid_file)s %(port)s %(timeout)s" % dict(
diff --git a/tests/scheduler/test_mysos_scheduler.py b/tests/scheduler/test_mysos_scheduler.py
index 3aeead1..ffb318b 100644
--- a/tests/scheduler/test_mysos_scheduler.py
+++ b/tests/scheduler/test_mysos_scheduler.py
@@ -12,6 +12,7 @@
 from twitter.common import log
 from twitter.common.concurrent import deadline
 from twitter.common.dirutil import safe_mkdtemp
+from twitter.common.metrics import RootMetrics
 from twitter.common.quantity import Amount, Time
 from zake.fake_client import FakeClient
 from zake.fake_storage import FakeStorage
@@ -64,6 +65,8 @@
       "/fakepath",
       gen_encryption_key())
 
+  RootMetrics().register_observable('scheduler', scheduler)
+
   scheduler_driver = mesos.native.MesosSchedulerDriver(
       scheduler,
       framework_info,
@@ -84,11 +87,14 @@
 
   scheduler.delete_cluster(cluster_name, password="passwd")
 
-  # A slave is promoted to be the master.
+  # The cluster is deleted from ZooKeeper.
   deadline(
       lambda: wait_for_termination(
           get_cluster_path(posixpath.join(zk_url, 'discover'), cluster_name),
           zk_client),
       Amount(40, Time.SECONDS))
 
+  sample = RootMetrics().sample()
+  assert sample['scheduler.tasks_killed'] == 1
+
   assert scheduler_driver.stop() == DRIVER_STOPPED