blob: ac591248ae49be9697f6879500c6a456d0caf955 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
import multiprocessing
import os
import signal
import subprocess
import time
from contextlib import suppress
from subprocess import CalledProcessError
from time import sleep
from unittest import mock
import psutil
import pytest
from airflow.exceptions import AirflowException
from airflow.utils import process_utils
from airflow.utils.process_utils import (
check_if_pidfile_process_is_running,
execute_in_subprocess,
execute_in_subprocess_with_kwargs,
set_new_process_group,
)
class TestReapProcessGroup:
    """Check that ``process_utils.reap_process_group`` removes processes that do not exit on SIGTERM."""

    @staticmethod
    def _ignores_sigterm(child_pid, child_setup_done):
        """Run forever with a no-op SIGTERM handler installed.

        :param child_pid: shared integer that receives this process's PID.
        :param child_setup_done: semaphore released once the handler is
            installed and the PID has been published.
        """

        def signal_handler(unused_signum, unused_frame):
            pass

        signal.signal(signal.SIGTERM, signal_handler)
        child_pid.value = os.getpid()
        child_setup_done.release()
        while True:
            time.sleep(1)

    @staticmethod
    def _parent_of_ignores_sigterm(parent_pid, child_pid, setup_done):
        """Start a new session, spawn a SIGTERM-ignoring child and run forever.

        :param parent_pid: shared integer that receives this process's PID.
        :param child_pid: shared integer handed to the child so it can publish
            its own PID.
        :param setup_done: semaphore released once both processes are set up.
        """

        def signal_handler(unused_signum, unused_frame):
            pass

        # Become a session leader so this process and its child can be
        # addressed together, separately from the test runner.
        os.setsid()
        signal.signal(signal.SIGTERM, signal_handler)

        child_setup_done = multiprocessing.Semaphore(0)
        child = multiprocessing.Process(
            target=TestReapProcessGroup._ignores_sigterm, args=[child_pid, child_setup_done]
        )
        child.start()
        # Wait (bounded) for the child to publish its PID before announcing
        # that the whole group is ready.
        child_setup_done.acquire(timeout=5.0)

        parent_pid.value = os.getpid()
        setup_done.release()
        while True:
            time.sleep(1)

    def test_reap_process_group(self):
        """
        Spin up a process that can't be killed by SIGTERM and make sure
        it gets killed anyway.
        """
        parent_setup_done = multiprocessing.Semaphore(0)
        parent_pid = multiprocessing.Value("i", 0)
        child_pid = multiprocessing.Value("i", 0)
        args = [parent_pid, child_pid, parent_setup_done]
        parent = multiprocessing.Process(target=TestReapProcessGroup._parent_of_ignores_sigterm, args=args)
        try:
            parent.start()
            assert parent_setup_done.acquire(timeout=5.0)
            assert psutil.pid_exists(parent_pid.value)
            assert psutil.pid_exists(child_pid.value)

            process_utils.reap_process_group(parent_pid.value, logging.getLogger(), timeout=1)

            assert not psutil.pid_exists(parent_pid.value)
            assert not psutil.pid_exists(child_pid.value)
        finally:
            # Best-effort cleanup of anything that survived. Each PID is
            # handled on its own so that a failure for one (typically because
            # it is already gone) does not skip the other. PIDs that were never
            # published are still 0 and must be skipped: os.kill(0, ...) would
            # signal every process in the test runner's own process group.
            for leftover_pid in (parent_pid.value, child_pid.value):
                if leftover_pid <= 0:
                    continue
                with suppress(OSError):
                    os.kill(leftover_pid, signal.SIGKILL)  # terminate doesn't work here
@pytest.mark.db_test
class TestExecuteInSubProcess:
    """Check what the subprocess execution helpers log and when they raise."""

    def test_should_print_all_messages1(self, caplog):
        """The command banner and every echoed line end up in the log records."""
        execute_in_subprocess(["bash", "-c", "echo CAT; echo KITTY;"])

        logged = [entry.getMessage() for entry in caplog.records]
        expected = [
            "Executing cmd: bash -c 'echo CAT; echo KITTY;'",
            "Output:",
            "CAT",
            "KITTY",
        ]
        assert logged == expected

    def test_should_print_all_messages_from_cwd(self, caplog, tmp_path):
        """With ``cwd`` given, ``pwd`` inside the command reports that directory."""
        workdir = str(tmp_path)
        execute_in_subprocess(["bash", "-c", "echo CAT; pwd; echo KITTY;"], cwd=workdir)

        logged = [entry.getMessage() for entry in caplog.records]
        expected = [
            "Executing cmd: bash -c 'echo CAT; pwd; echo KITTY;'",
            "Output:",
            "CAT",
            workdir,
            "KITTY",
        ]
        assert logged == expected

    def test_should_raise_exception(self):
        """A command exiting with status 1 surfaces as ``CalledProcessError``."""
        failing_cmd = ["bash", "-c", "exit 1"]
        with pytest.raises(CalledProcessError):
            process_utils.execute_in_subprocess(failing_cmd)

    def test_using_env_as_kwarg_works(self, caplog):
        """An ``env`` keyword argument is visible to the executed command."""
        cmd = ["bash", "-c", 'echo "My value is ${VALUE}"']
        execute_in_subprocess_with_kwargs(cmd, env={"VALUE": "1"})
        assert "My value is 1" in caplog.text
def my_sleep_subprocess():
    """Block for 100 seconds; serves as a long-running process target."""
    duration_seconds = 100
    sleep(duration_seconds)
def my_sleep_subprocess_with_signals():
    """Install no-op handlers for SIGINT and SIGTERM, then block for 100 seconds."""

    def _do_nothing(signum, frame):
        return None

    # Same registration order as before: SIGINT first, then SIGTERM.
    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, _do_nothing)
    sleep(100)
@pytest.mark.db_test
class TestKillChildProcessesByPids:
    """Check that ``process_utils.kill_child_processes_by_pids`` removes the targeted child."""

    @staticmethod
    def _listed_pids():
        """Return the PIDs (as stripped strings) currently reported by ``ps -ax``."""
        ps_output = subprocess.check_output(["ps", "-ax", "-o", "pid="]).decode()
        return {line.strip() for line in ps_output.splitlines()}

    def test_should_kill_process(self):
        """A plain sleeping child is gone from the process table after the call."""
        process = multiprocessing.Process(target=my_sleep_subprocess, args=())
        process.start()
        sleep(0)

        # Look for the spawned PID itself rather than comparing system-wide
        # process counts: unrelated processes starting or exiting while the
        # test runs would otherwise change the count and fail the test.
        assert str(process.pid) in self._listed_pids()

        process_utils.kill_child_processes_by_pids([process.pid])

        assert str(process.pid) not in self._listed_pids()

    def test_should_force_kill_process(self, caplog):
        """A child with no-op SIGINT/SIGTERM handlers is killed and the kill is logged."""
        process = multiprocessing.Process(target=my_sleep_subprocess_with_signals, args=())
        process.start()
        sleep(0)

        assert str(process.pid) in self._listed_pids()

        with caplog.at_level(logging.INFO, logger=process_utils.log.name):
            # Drop anything captured so far so only this call's messages are checked.
            caplog.clear()
            process_utils.kill_child_processes_by_pids([process.pid], timeout=0)
        assert f"Killing child PID: {process.pid}" in caplog.messages

        sleep(0)
        assert str(process.pid) not in self._listed_pids()
class TestPatchEnviron:
    """Check that ``process_utils.patch_environ`` applies overrides and undoes them on exit."""

    def test_should_update_variable_and_restore_state_when_exit(self):
        """Overrides are visible inside the context and reverted after a normal exit."""
        baseline = {"TEST_NOT_EXISTS": "BEFORE", "TEST_EXISTS": "BEFORE"}
        overrides = {"TEST_NOT_EXISTS": "AFTER", "TEST_EXISTS": "AFTER"}

        with mock.patch.dict("os.environ", baseline):
            # Start with one variable present and one absent.
            del os.environ["TEST_NOT_EXISTS"]
            assert os.environ["TEST_EXISTS"] == "BEFORE"
            assert "TEST_NOT_EXISTS" not in os.environ

            with process_utils.patch_environ(overrides):
                assert os.environ["TEST_NOT_EXISTS"] == "AFTER"
                assert os.environ["TEST_EXISTS"] == "AFTER"

            assert os.environ["TEST_EXISTS"] == "BEFORE"
            assert "TEST_NOT_EXISTS" not in os.environ

    def test_should_restore_state_when_exception(self):
        """Overrides are reverted even when the context body raises."""
        baseline = {"TEST_NOT_EXISTS": "BEFORE", "TEST_EXISTS": "BEFORE"}
        overrides = {"TEST_NOT_EXISTS": "AFTER", "TEST_EXISTS": "AFTER"}

        with mock.patch.dict("os.environ", baseline):
            # Start with one variable present and one absent.
            del os.environ["TEST_NOT_EXISTS"]
            assert os.environ["TEST_EXISTS"] == "BEFORE"
            assert "TEST_NOT_EXISTS" not in os.environ

            # The exception leaves patch_environ first and is then swallowed
            # by suppress, so the checks below still run.
            with suppress(AirflowException), process_utils.patch_environ(overrides):
                assert os.environ["TEST_NOT_EXISTS"] == "AFTER"
                assert os.environ["TEST_EXISTS"] == "AFTER"
                raise AirflowException("Unknown exception")

            assert os.environ["TEST_EXISTS"] == "BEFORE"
            assert "TEST_NOT_EXISTS" not in os.environ
class TestCheckIfPidfileProcessIsRunning:
    """Check how ``check_if_pidfile_process_is_running`` reacts to missing, stale and live pidfiles."""

    def test_ok_if_no_file(self):
        """A pidfile path that does not exist is accepted without raising."""
        missing_pidfile = "some/pid/file"
        check_if_pidfile_process_is_running(missing_pidfile, process_name="test")

    def test_remove_if_no_process(self, tmp_path):
        """A pidfile naming a PID with no process behind it is deleted."""
        # Cap the PID at the int32 maximum; a larger value could make this
        # test fail on some platforms.
        stale_pid = 2**31 - 1
        pidfile = tmp_path / "testfile"
        pidfile.write_text(str(stale_pid))

        check_if_pidfile_process_is_running(os.fspath(pidfile), process_name="test")

        # The stale pidfile must be gone afterwards.
        assert not pidfile.exists()

    def test_raise_error_if_process_is_running(self, tmp_path):
        """A pidfile naming the current (live) process raises ``AirflowException``."""
        pidfile = tmp_path / "testfile"
        pidfile.write_text(str(os.getpid()))

        with pytest.raises(AirflowException, match="is already running under PID"):
            check_if_pidfile_process_is_running(os.fspath(pidfile), process_name="test")
class TestSetNewProcessGroup:
    """Check when ``set_new_process_group`` calls ``os.setpgid``."""

    @mock.patch("os.setpgid")
    def test_not_session_leader(self, mock_set_pid):
        """When ``os.getsid`` reports a different ID than our PID, ``os.setpgid`` is called once."""
        other_session_id = os.getpid() + 1
        with mock.patch("os.getsid", autospec=True) as fake_getsid:
            fake_getsid.return_value = other_session_id
            set_new_process_group()
            mock_set_pid.assert_called_once()

    @mock.patch("os.setpgid")
    def test_session_leader(self, mock_set_pid):
        """When ``os.getsid`` reports our own PID, ``os.setpgid`` is not called."""
        own_pid = os.getpid()
        with mock.patch("os.getsid", autospec=True) as fake_getsid:
            fake_getsid.return_value = own_pid
            set_new_process_group()
            mock_set_pid.assert_not_called()