[AIRFLOW-6591] Add cli option to stop celery worker (#7206)

Now users can gracefully stop the Celery worker by sending a
SIGTERM signal to Celery main process.
diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index 26dc25d..1672f75 100644
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -1023,6 +1023,12 @@
                         'flower_hostname', 'flower_port', 'flower_conf', 'flower_url_prefix',
                         'flower_basic_auth', 'broker_api', 'pid', 'daemon', 'stdout', 'stderr', 'log_file'),
                 },
+                {
+                    'name': 'stop',
+                    'func': lazy_load_command('airflow.cli.commands.celery_command.stop_worker'),
+                    'help': "Stop the Celery worker gracefully",
+                    'args': (),
+                }
             )
         })
     subparsers_dict = {sp.get('name') or sp['func'].__name__: sp for sp in subparsers}  # type: ignore
diff --git a/airflow/cli/commands/celery_command.py b/airflow/cli/commands/celery_command.py
index 5e8dd2e..e9f7752 100644
--- a/airflow/cli/commands/celery_command.py
+++ b/airflow/cli/commands/celery_command.py
@@ -23,8 +23,10 @@
 from typing import Optional
 
 import daemon
+import psutil
 from celery.bin import worker as worker_bin
 from daemon.pidfile import TimeoutPIDLockFile
+from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile
 
 from airflow import settings
 from airflow.configuration import conf
@@ -33,6 +35,8 @@
 from airflow.utils.cli import setup_locations, setup_logging, sigint_handler
 from airflow.utils.serve_logs import serve_logs
 
+WORKER_PROCESS_NAME = "worker"
+
 
 @cli_utils.action_logging
 def flower(args):
@@ -93,9 +97,6 @@
 @cli_utils.action_logging
 def worker(args):
     """Starts Airflow Celery worker"""
-    env = os.environ.copy()
-    env['AIRFLOW_HOME'] = settings.AIRFLOW_HOME
-
     if not settings.validate_session():
         print("Worker exiting... database connection precheck failed! ")
         sys.exit(1)
@@ -106,6 +107,16 @@
     if autoscale is None and conf.has_option("celery", "worker_autoscale"):
         autoscale = conf.get("celery", "worker_autoscale")
 
+    # Setup locations
+    pid_file_path, stdout, stderr, log_file = setup_locations(
+        process=WORKER_PROCESS_NAME,
+        pid=args.pid,
+        stdout=args.stdout,
+        stderr=args.stderr,
+        log=args.log_file,
+    )
+
+    # Setup Celery worker
     worker_instance = worker_bin.worker(app=celery_app)
     options = {
         'optimization': 'fair',
@@ -115,23 +126,19 @@
         'autoscale': autoscale,
         'hostname': args.celery_hostname,
         'loglevel': conf.get('logging', 'LOGGING_LEVEL'),
+        'pidfile': pid_file_path,
     }
 
     if conf.has_option("celery", "pool"):
         options["pool"] = conf.get("celery", "pool")
 
     if args.daemon:
-        pid, stdout, stderr, log_file = setup_locations("worker",
-                                                        args.pid,
-                                                        args.stdout,
-                                                        args.stderr,
-                                                        args.log_file)
+        # Run Celery worker as daemon
         handle = setup_logging(log_file)
         stdout = open(stdout, 'w+')
         stderr = open(stderr, 'w+')
 
         ctx = daemon.DaemonContext(
-            pidfile=TimeoutPIDLockFile(pid, -1),
             files_preserve=[handle],
             stdout=stdout,
             stderr=stderr,
@@ -143,11 +150,25 @@
         stdout.close()
         stderr.close()
     else:
-        signal.signal(signal.SIGINT, sigint_handler)
-        signal.signal(signal.SIGTERM, sigint_handler)
-
+        # Run Celery worker in the same process
         sub_proc = _serve_logs(skip_serve_logs)
         worker_instance.run(**options)
 
     if sub_proc:
         sub_proc.terminate()
+
+
+@cli_utils.action_logging
+def stop_worker(args):  # pylint: disable=unused-argument
+    """Sends SIGTERM to Celery worker"""
+    # Read PID from file
+    pid_file_path, _, _, _ = setup_locations(process=WORKER_PROCESS_NAME)
+    pid = read_pid_from_pidfile(pid_file_path)
+
+    # Send SIGTERM
+    if pid:
+        worker_process = psutil.Process(pid)
+        worker_process.terminate()
+
+    # Remove pid file
+    remove_existing_pidfile(pid_file_path)
diff --git a/docs/executor/celery.rst b/docs/executor/celery.rst
index bf86f19..2029130 100644
--- a/docs/executor/celery.rst
+++ b/docs/executor/celery.rst
@@ -53,11 +53,23 @@
     airflow celery worker
 
 Your worker should start picking up tasks as soon as they get fired in
-its direction.
+its direction. To stop a worker running on a machine you can use:
 
-Note that you can also run "Celery Flower", a web UI built on top of Celery,
-to monitor your workers. You can use the shortcut command ``airflow celery flower``
-to start a Flower web server.
+.. code-block:: bash
+
+    airflow celery stop
+
+It will try to stop the worker gracefully by sending ``SIGTERM`` signal to main Celery
+process as recommended by
+`Celery documentation <https://docs.celeryproject.org/en/latest/userguide/workers>`__.
+
+Note that you can also run `Celery Flower <https://flower.readthedocs.io/en/latest/>`__,
+a web UI built on top of Celery, to monitor your workers. You can use the shortcut command
+to start a Flower web server:
+
+.. code-block:: bash
+
+    airflow celery stop
 
 Please note that you must have the ``flower`` python library already installed on your system. The recommend way is to install the airflow celery bundle.
 
diff --git a/setup.py b/setup.py
index aef0b6f..b961632 100644
--- a/setup.py
+++ b/setup.py
@@ -472,6 +472,7 @@
             'json-merge-patch==0.2',
             'jsonschema~=3.0',
             'lazy_object_proxy~=1.3',
+            'lockfile>=0.12.2',
             'markdown>=2.5.2, <3.0',
             'pandas>=0.17.1, <1.0.0',
             'pendulum==1.4.4',
diff --git a/tests/cli/commands/test_celery_command.py b/tests/cli/commands/test_celery_command.py
index 6bb4073..ee78fc2 100644
--- a/tests/cli/commands/test_celery_command.py
+++ b/tests/cli/commands/test_celery_command.py
@@ -18,6 +18,7 @@
 import importlib
 import unittest
 from argparse import Namespace
+from tempfile import NamedTemporaryFile
 
 import mock
 import pytest
@@ -93,3 +94,56 @@
                 mock_privil.return_value = 0
                 celery_command.worker(args)
                 mock_popen.assert_not_called()
+
+
+class TestCeleryStopCommand(unittest.TestCase):
+    @classmethod
+    @conf_vars({("core", "executor"): "CeleryExecutor"})
+    def setUpClass(cls):
+        importlib.reload(cli)
+        cls.parser = cli.CLIFactory.get_parser()
+
+    @mock.patch("airflow.cli.commands.celery_command.setup_locations")
+    @mock.patch("airflow.cli.commands.celery_command.psutil.Process")
+    def test_if_right_pid_is_read(self, mock_process, mock_setup_locations):
+        args = self.parser.parse_args(['celery', 'stop'])
+        pid = "123"
+
+        # Calling stop_worker should delete the temporary pid file
+        with self.assertRaises(FileNotFoundError):
+            with NamedTemporaryFile("w+") as f:
+                # Create pid file
+                f.write(pid)
+                f.flush()
+                # Setup mock
+                mock_setup_locations.return_value = (f.name, None, None, None)
+                # Check if works as expected
+                celery_command.stop_worker(args)
+                mock_process.assert_called_once_with(int(pid))
+                mock_process.return_value.terminate.assert_called_once_with()
+
+    @mock.patch("airflow.cli.commands.celery_command.read_pid_from_pidfile")
+    @mock.patch("airflow.cli.commands.celery_command.worker_bin.worker")
+    @mock.patch("airflow.cli.commands.celery_command.setup_locations")
+    def test_same_pid_file_is_used_in_start_and_stop(
+        self,
+        mock_setup_locations,
+        mock_celery_worker,
+        mock_read_pid_from_pidfile
+    ):
+        pid_file = "test_pid_file"
+        mock_setup_locations.return_value = (pid_file, None, None, None)
+        mock_read_pid_from_pidfile.return_value = None
+
+        # Call worker
+        worker_args = self.parser.parse_args(['celery', 'worker', '-s'])
+        celery_command.worker(worker_args)
+        run_mock = mock_celery_worker.return_value.run
+        assert run_mock.call_args
+        assert 'pidfile' in run_mock.call_args.kwargs
+        assert run_mock.call_args.kwargs['pidfile'] == pid_file
+
+        # Call stop
+        stop_args = self.parser.parse_args(['celery', 'stop'])
+        celery_command.stop_worker(stop_args)
+        mock_read_pid_from_pidfile.assert_called_once_with(pid_file)