| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from __future__ import annotations |
| |
| import inspect |
| import json |
| import logging |
| import os |
| import re |
| import selectors |
| import signal |
| import socket |
| import subprocess |
| import sys |
| import time |
| from contextlib import nullcontext |
| from dataclasses import dataclass, field |
| from datetime import datetime |
| from operator import attrgetter |
| from random import randint |
| from textwrap import dedent |
| from time import sleep |
| from typing import TYPE_CHECKING, Any |
| from unittest import mock |
| from unittest.mock import MagicMock, patch |
| |
| import httpx |
| import msgspec |
| import psutil |
| import pytest |
| import structlog |
| from pytest_unordered import unordered |
| from task_sdk import FAKE_BUNDLE, make_client |
| from uuid6 import uuid7 |
| |
| from airflow.executors.workloads import BundleInfo |
| from airflow.sdk import BaseOperator, timezone |
| from airflow.sdk.api import client as sdk_client |
| from airflow.sdk.api.client import ServerResponseError |
| from airflow.sdk.api.datamodels._generated import ( |
| AssetEventResponse, |
| AssetProfile, |
| AssetResponse, |
| DagRun, |
| DagRunState, |
| DagRunType, |
| PreviousTIResponse, |
| TaskInstance, |
| TaskInstanceState, |
| ) |
| from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType |
| from airflow.sdk.execution_time import task_runner |
| from airflow.sdk.execution_time.comms import ( |
| AssetEventsResult, |
| AssetResult, |
| CommsDecoder, |
| ConnectionResult, |
| CreateHITLDetailPayload, |
| DagRunResult, |
| DagRunStateResult, |
| DeferTask, |
| DeleteVariable, |
| DeleteXCom, |
| DRCount, |
| ErrorResponse, |
| GetAssetByName, |
| GetAssetByUri, |
| GetAssetEventByAsset, |
| GetAssetEventByAssetAlias, |
| GetConnection, |
| GetDagRun, |
| GetDagRunState, |
| GetDRCount, |
| GetHITLDetailResponse, |
| GetPreviousDagRun, |
| GetPreviousTI, |
| GetPrevSuccessfulDagRun, |
| GetTaskBreadcrumbs, |
| GetTaskRescheduleStartDate, |
| GetTaskStates, |
| GetTICount, |
| GetVariable, |
| GetXCom, |
| GetXComCount, |
| GetXComSequenceItem, |
| GetXComSequenceSlice, |
| HITLDetailRequestResult, |
| InactiveAssetsResult, |
| MaskSecret, |
| OKResponse, |
| PreviousDagRunResult, |
| PreviousTIResult, |
| PrevSuccessfulDagRunResult, |
| PutVariable, |
| RescheduleTask, |
| ResendLoggingFD, |
| RetryTask, |
| SentFDs, |
| SetRenderedFields, |
| SetRenderedMapIndex, |
| SetXCom, |
| SkipDownstreamTasks, |
| SucceedTask, |
| TaskBreadcrumbsResult, |
| TaskRescheduleStartDate, |
| TaskState, |
| TaskStatesResult, |
| TICount, |
| ToSupervisor, |
| TriggerDagRun, |
| UpdateHITLDetail, |
| ValidateInletsAndOutlets, |
| VariableResult, |
| XComCountResponse, |
| XComResult, |
| XComSequenceIndexResult, |
| XComSequenceSliceResult, |
| _RequestFrame, |
| _ResponseFrame, |
| ) |
| from airflow.sdk.execution_time.supervisor import ( |
| ActivitySubprocess, |
| InProcessSupervisorComms, |
| InProcessTestSupervisor, |
| _make_process_nondumpable, |
| _remote_logging_conn, |
| process_log_messages_from_subprocess, |
| set_supervisor_comms, |
| supervise, |
| ) |
| from airflow.sdk.execution_time.task_runner import run |
| |
| from tests_common.test_utils.config import conf_vars |
| |
| if TYPE_CHECKING: |
| import kgb |
| |
| log = logging.getLogger(__name__) |
| TI_ID = uuid7() |
| |
| |
| def lineno(): |
| """Returns the current line number in our program.""" |
| return inspect.currentframe().f_back.f_lineno |
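

# Used below as e.g. `line = lineno() - 2`, so log assertions can pin a record's
# `loc`/`lineno` fields to the exact source line that emitted the message.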
| |
| |
| def local_dag_bundle_cfg(path, name="my-bundle"): |
| return { |
| "AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST": json.dumps( |
| [ |
| { |
| "name": name, |
| "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", |
| "kwargs": {"path": str(path), "refresh_interval": 1}, |
| } |
| ] |
| ) |
| } |
| |
| |
| @pytest.fixture |
| def client_with_ti_start(make_ti_context): |
| client = MagicMock(spec=sdk_client.Client) |
| client.task_instances.start.return_value = make_ti_context() |
| return client |
| |
| |
| @pytest.mark.usefixtures("disable_capturing") |
| class TestSupervisor: |
| @pytest.mark.parametrize( |
| ("server", "dry_run", "expectation"), |
| [ |
| ("/execution/", False, pytest.raises(ValueError, match="Invalid execution API server URL")), |
| ("", False, pytest.raises(ValueError, match="Invalid execution API server URL")), |
| ("http://localhost:8080", True, pytest.raises(ValueError, match="Can only specify one of")), |
| (None, True, nullcontext()), |
| ("http://localhost:8080/execution/", False, nullcontext()), |
| ("https://localhost:8080/execution/", False, nullcontext()), |
| ], |
| ) |
| def test_supervise( |
| self, |
| server, |
| dry_run, |
| expectation, |
| test_dags_dir, |
| client_with_ti_start, |
| ): |
| """ |
| Test that the supervisor validates server URL and dry_run parameter combinations correctly. |
| """ |
| ti = TaskInstance( |
| id=uuid7(), |
| task_id="async", |
| dag_id="super_basic_deferred_run", |
| run_id="d", |
| try_number=1, |
| dag_version_id=uuid7(), |
| ) |
| |
| bundle_info = BundleInfo(name="my-bundle", version=None) |
| |
| kw = { |
| "ti": ti, |
| "dag_rel_path": "super_basic_deferred_run.py", |
| "token": "", |
| "bundle_info": bundle_info, |
| "dry_run": dry_run, |
| "server": server, |
| } |
| if isinstance(expectation, nullcontext): |
| kw["client"] = client_with_ti_start |
| |
| with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): |
| with expectation: |
| supervise(**kw) |
| |
| |
| @pytest.mark.usefixtures("disable_capturing") |
| class TestWatchedSubprocess: |
| @pytest.fixture(autouse=True) |
| def disable_log_upload(self, spy_agency): |
| spy_agency.spy_on(ActivitySubprocess._upload_logs, call_original=False) |
| |
| @pytest.fixture(autouse=True) |
| def use_real_secrets_backends(self, monkeypatch): |
| """ |
| Ensure that real secrets backend instances are used instead of mocks. |
| |
        This prevents a Python 3.13 RuntimeWarning when hasattr checks async methods
        on mocked backends. The warning occurs because hasattr on AsyncMock creates
        unawaited coroutines.
| |
| This fixture ensures test isolation when running in parallel with pytest-xdist, |
| regardless of what other tests patch. |
| """ |
| import importlib |
| |
| import airflow.sdk.execution_time.secrets.execution_api as execution_api_module |
| from airflow.secrets.environment_variables import EnvironmentVariablesBackend |
| |
| fresh_execution_backend = importlib.reload(execution_api_module).ExecutionAPISecretsBackend |
| |
| # Ensure downstream imports see the restored class instead of any AsyncMock left by other tests |
| import airflow.sdk.execution_time.secrets as secrets_package |
| |
| monkeypatch.setattr(secrets_package, "ExecutionAPISecretsBackend", fresh_execution_backend) |
| |
| monkeypatch.setattr( |
| "airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded", |
| lambda: [EnvironmentVariablesBackend(), fresh_execution_backend()], |
| ) |
| |
| def test_reading_from_pipes(self, captured_logs, time_machine, client_with_ti_start): |
| def subprocess_main(): |
| # This is run in the subprocess! |
| |
| # Ensure we follow the "protocol" and get the startup message before we do anything else |
| CommsDecoder()._get_response() |
| |
| import logging |
| import warnings |
| |
| print("I'm a short message") |
| sys.stdout.write("Message ") |
| print("stderr message", file=sys.stderr) |
            # We need a short sleep for the main process to process things. I worry this timing will be
            # fragile, but I can't think of a better way. This lets the partial stdout line and the full
            # stderr line be read.
| sleep(0.1) |
| sys.stdout.write("split across two writes\n") |
| |
| logging.getLogger("airflow.foobar").error("An error message") |
| |
            warnings.warn("Warning should appear from the correct callsite", stacklevel=1)
| |
        line = lineno() - 2  # Line the warning should be on
| |
| instant = timezone.datetime(2024, 11, 7, 12, 34, 56, 78901) |
| time_machine.move_to(instant, tick=False) |
| |
| proc = ActivitySubprocess.start( |
| dag_rel_path=os.devnull, |
| bundle_info=FAKE_BUNDLE, |
| what=TaskInstance( |
| id="4d828a62-a417-4936-a7a6-2b3fabacecab", |
| task_id="b", |
| dag_id="c", |
| run_id="d", |
| try_number=1, |
| dag_version_id=uuid7(), |
| ), |
| client=client_with_ti_start, |
| target=subprocess_main, |
| ) |
| |
| rc = proc.wait() |
| |
| assert rc == 0 |
| assert captured_logs == unordered( |
| [ |
| { |
| "logger": "task.stdout", |
| "event": "I'm a short message", |
| "level": "info", |
| "timestamp": "2024-11-07T12:34:56.078901Z", |
| }, |
| { |
| "logger": "task.stderr", |
| "event": "stderr message", |
| "level": "error", |
| "timestamp": "2024-11-07T12:34:56.078901Z", |
| }, |
| { |
| "logger": "task.stdout", |
| "event": "Message split across two writes", |
| "level": "info", |
| "timestamp": "2024-11-07T12:34:56.078901Z", |
| }, |
| { |
| "event": "An error message", |
| "level": "error", |
| "logger": "airflow.foobar", |
| "timestamp": instant, |
| "loc": mock.ANY, |
| }, |
| { |
| "category": "UserWarning", |
| "event": "Warning should be appear from the correct callsite", |
| "filename": __file__, |
| "level": "warning", |
| "lineno": line, |
| "logger": "py.warnings", |
| "timestamp": instant, |
| }, |
| ] |
| ) |
| |
| @pytest.mark.flaky(reruns=3) |
| def test_reopen_log_fd(self, captured_logs, client_with_ti_start): |
| def subprocess_main(): |
| # This is run in the subprocess! |
| |
| # Ensure we follow the "protocol" and get the startup message before we do anything else |
| comms = CommsDecoder() |
| comms._get_response() |
| |
| logs = comms.send(ResendLoggingFD()) |
| assert isinstance(logs, SentFDs) |
| logging.root.info("Log on old socket") |
| with os.fdopen(logs.fds[0], "w") as fd: |
| json.dump({"level": "info", "event": "Log on new socket"}, fp=fd) |
| fd.write("\n") |
| |
        line = lineno() - 5  # Line the log call should be on
| |
| proc = ActivitySubprocess.start( |
| dag_rel_path=os.devnull, |
| bundle_info=FAKE_BUNDLE, |
| what=TaskInstance( |
| id="4d828a62-a417-4936-a7a6-2b3fabacecab", |
| task_id="b", |
| dag_id="c", |
| run_id="d", |
| try_number=1, |
| dag_version_id=uuid7(), |
| ), |
| client=client_with_ti_start, |
| target=subprocess_main, |
| ) |
| |
| rc = proc.wait() |
| |
| assert rc == 0 |
| assert captured_logs == unordered( |
| [ |
| { |
| "event": "Log on new socket", |
| "level": "info", |
| "logger": "task", |
| "timestamp": mock.ANY, |
                    # Since this is sent as JSON without filename or lineno, we _should_ not add any.
| }, |
| { |
| "event": "Log on old socket", |
| "level": "info", |
| "logger": "root", |
| "timestamp": mock.ANY, |
| "loc": f"{os.path.basename(__file__)}:{line}", |
| }, |
| ] |
| ) |
| |
| def test_on_kill_hook_called_when_sigkilled( |
| self, |
| client_with_ti_start, |
| mocked_parse, |
| make_ti_context, |
| mock_supervisor_comms, |
| create_runtime_ti, |
| make_ti_context_dict, |
| capfd, |
| ): |
| main_pid = os.getpid() |
| ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab" |
| |
| def handle_request(request: httpx.Request) -> httpx.Response: |
| if request.url.path == f"/task-instances/{ti_id}/heartbeat": |
| return httpx.Response( |
| status_code=409, |
| json={ |
| "detail": { |
| "reason": "not_running", |
| "message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate", |
| "current_state": "failed", |
| } |
| }, |
| ) |
| if request.url.path == f"/task-instances/{ti_id}/run": |
| return httpx.Response(200, json=make_ti_context_dict()) |
| return httpx.Response(status_code=204) |
| |
| def subprocess_main(): |
| # Ensure we follow the "protocol" and get the startup message before we do anything |
| CommsDecoder()._get_response() |
| |
| class CustomOperator(BaseOperator): |
| def execute(self, context): |
| for i in range(1000): |
| print(f"Iteration {i}") |
| sleep(1) |
| |
| def on_kill(self) -> None: |
| print("On kill hook called!") |
| |
| task = CustomOperator(task_id="print-params") |
| runtime_ti = create_runtime_ti( |
| dag_id="c", |
| task=task, |
| conf={ |
| "x": 3, |
| "text": "Hello World!", |
| "flag": False, |
| "a_simple_list": ["one", "two", "three", "actually one value is made per line"], |
| }, |
| ) |
| run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) |
| |
| assert os.getpid() != main_pid |
| os.kill(os.getpid(), signal.SIGTERM) |
| # Ensure that the signal is serviced before we finish and exit the subprocess. |
| sleep(0.5) |
| |
| proc = ActivitySubprocess.start( |
| dag_rel_path=os.devnull, |
| bundle_info=FAKE_BUNDLE, |
| what=TaskInstance( |
| id=ti_id, |
| task_id="b", |
| dag_id="c", |
| run_id="d", |
| try_number=1, |
| dag_version_id=uuid7(), |
| ), |
| client=make_client(transport=httpx.MockTransport(handle_request)), |
| target=subprocess_main, |
| ) |
| |
| proc.wait() |
| captured = capfd.readouterr() |
| assert "On kill hook called!" in captured.out |
| |
| def test_subprocess_sigkilled(self, client_with_ti_start): |
| main_pid = os.getpid() |
| |
| def subprocess_main(): |
| # Ensure we follow the "protocol" and get the startup message before we do anything |
| CommsDecoder()._get_response() |
| |
| assert os.getpid() != main_pid |
| os.kill(os.getpid(), signal.SIGKILL) |
| |
| proc = ActivitySubprocess.start( |
| dag_rel_path=os.devnull, |
| bundle_info=FAKE_BUNDLE, |
| what=TaskInstance( |
| id="4d828a62-a417-4936-a7a6-2b3fabacecab", |
| task_id="b", |
| dag_id="c", |
| run_id="d", |
| try_number=1, |
| dag_version_id=uuid7(), |
| ), |
| client=client_with_ti_start, |
| target=subprocess_main, |
| ) |
| |
| rc = proc.wait() |
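        # Death by signal is reported as a negative return code, so SIGKILL (9) surfaces as -9.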
| |
| assert rc == -9 |
| |
| def test_last_chance_exception_handling(self, capfd): |
| def subprocess_main(): |
| # The real main() in task_runner catches exceptions! This is what would happen if we had a syntax |
| # or import error for instance - a very early exception |
| raise RuntimeError("Fake syntax error") |
| |
| proc = ActivitySubprocess.start( |
| dag_rel_path=os.devnull, |
| bundle_info=FAKE_BUNDLE, |
| what=TaskInstance( |
| id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() |
| ), |
| client=MagicMock(spec=sdk_client.Client), |
| target=subprocess_main, |
| ) |
| |
| rc = proc.wait() |
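        # task_runner's last-chance handler exits with code 126 for these very early failures.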
| |
| assert rc == 126 |
| |
| captured = capfd.readouterr() |
| assert "Last chance exception handler" in captured.err |
| assert "RuntimeError: Fake syntax error" in captured.err |
| |
    def test_regular_heartbeat(self, spy_agency: kgb.SpyAgency, monkeypatch, mocker, make_ti_context):
        """Test that the WatchedSubprocess class regularly sends heartbeat requests, up to a certain frequency."""
| import airflow.sdk.execution_time.supervisor |
| |
| monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) |
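        # With the interval this low, the supervisor can fit several heartbeats into the
        # ~0.25s (5 x 0.05s) that the subprocess below stays alive.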
| |
| def subprocess_main(): |
| CommsDecoder()._get_response() |
| |
| for _ in range(5): |
| print("output", flush=True) |
| sleep(0.05) |
| |
| ti_id = uuid7() |
| _ = mocker.patch.object(sdk_client.TaskInstanceOperations, "start", return_value=make_ti_context()) |
| |
| spy = spy_agency.spy_on(sdk_client.TaskInstanceOperations.heartbeat) |
| proc = ActivitySubprocess.start( |
| dag_rel_path=os.devnull, |
| bundle_info=FAKE_BUNDLE, |
| what=TaskInstance( |
| id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() |
| ), |
| client=sdk_client.Client(base_url="", dry_run=True, token=""), |
| target=subprocess_main, |
| ) |
| assert proc.wait() == 0 |
| assert spy.called_with(ti_id, pid=proc.pid) # noqa: PGH005 |
| # The exact number we get will depend on timing behaviour, so be a little lenient |
| assert 1 <= len(spy.calls) <= 4 |
| |
    def test_no_heartbeat_in_overtime(self, spy_agency: kgb.SpyAgency, monkeypatch, mocker, make_ti_context):
        """Test that we don't try to send heartbeats for tasks that are in "overtime"."""
| import airflow.sdk.execution_time.supervisor |
| |
| monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) |
| |
| def subprocess_main(): |
| CommsDecoder()._get_response() |
| |
| for _ in range(5): |
| print("output", flush=True) |
| sleep(0.05) |
| |
| ti_id = uuid7() |
| _ = mocker.patch.object(sdk_client.TaskInstanceOperations, "start", return_value=make_ti_context()) |
| |
| @spy_agency.spy_for(ActivitySubprocess._on_child_started) |
| def _on_child_started(self, *args, **kwargs): |
| # Set it up so we are in overtime straight away |
| self._terminal_state = TaskInstanceState.SUCCESS |
| ActivitySubprocess._on_child_started.call_original(self, *args, **kwargs) |
| |
| heartbeat_spy = spy_agency.spy_on(sdk_client.TaskInstanceOperations.heartbeat) |
| proc = ActivitySubprocess.start( |
| dag_rel_path=os.devnull, |
| bundle_info=FAKE_BUNDLE, |
| what=TaskInstance( |
| id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() |
| ), |
| client=sdk_client.Client(base_url="", dry_run=True, token=""), |
| target=subprocess_main, |
| ) |
| assert proc.wait() == 0 |
| spy_agency.assert_spy_not_called(heartbeat_spy) |
| |
| def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, client_with_ti_start): |
| """Test running a simple DAG in a subprocess and capturing the output.""" |
| |
| instant = timezone.datetime(2024, 11, 7, 12, 34, 56, 78901) |
| time_machine.move_to(instant, tick=False) |
| |
| dagfile_path = test_dags_dir |
| ti = TaskInstance( |
| id=uuid7(), |
| task_id="hello", |
| dag_id="super_basic_run", |
| run_id="c", |
| try_number=1, |
| dag_version_id=uuid7(), |
| ) |
| |
| bundle_info = BundleInfo(name="my-bundle", version=None) |
| with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): |
| exit_code = supervise( |
| ti=ti, |
| dag_rel_path=dagfile_path, |
| token="", |
| server="", |
| dry_run=True, |
| client=client_with_ti_start, |
| bundle_info=bundle_info, |
| ) |
| assert exit_code == 0, captured_logs |
| |
| # We should have a log from the task! |
| assert { |
| "logger": "task.stdout", |
| "event": "Hello World hello!", |
| "level": "info", |
| "timestamp": "2024-11-07T12:34:56.078901Z", |
| } in captured_logs |
| |
| def test_supervise_handles_deferred_task( |
| self, test_dags_dir, captured_logs, time_machine, mocker, make_ti_context |
| ): |
| """ |
| Test that the supervisor handles a deferred task correctly. |
| |
| This includes ensuring the task starts and executes successfully, and that the task is deferred (via |
| the API client) with the expected parameters. |
| """ |
| instant = timezone.datetime(2024, 11, 7, 12, 34, 56, 0) |
| |
| ti = TaskInstance( |
| id=uuid7(), |
| task_id="async", |
| dag_id="super_basic_deferred_run", |
| run_id="d", |
| try_number=1, |
| dag_version_id=uuid7(), |
| ) |
| |
| # Create a mock client to assert calls to the client |
| # We assume the implementation of the client is correct and only need to check the calls |
| mock_client = mocker.Mock(spec=sdk_client.Client) |
| mock_client.task_instances.start.return_value = make_ti_context() |
| |
| time_machine.move_to(instant, tick=False) |
| current = 1_000_000.0 |
| |
| def mock_monotonic(): |
| return current |
| |
| bundle_info = BundleInfo(name="my-bundle", version=None) |
| with ( |
| patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)), |
| patch("airflow.sdk.execution_time.supervisor.time.monotonic", side_effect=mock_monotonic), |
| ): |
| exit_code = supervise( |
| ti=ti, |
| dag_rel_path="super_basic_deferred_run.py", |
| token="", |
| client=mock_client, |
| bundle_info=bundle_info, |
| ) |
| assert exit_code == 0, captured_logs |
| |
| # Validate calls to the client |
| mock_client.task_instances.start.assert_called_once_with(ti.id, mocker.ANY, mocker.ANY) |
| mock_client.task_instances.heartbeat.assert_called_once_with(ti.id, pid=mocker.ANY) |
| mock_client.task_instances.defer.assert_called_once_with( |
| ti.id, |
            # Since the message is serialized in the client upon sending, we expect it to already be encoded
| DeferTask( |
| classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger", |
| next_method="execute_complete", |
| trigger_kwargs={ |
| "moment": { |
| "__classname__": "pendulum.datetime.DateTime", |
| "__version__": 2, |
| "__data__": { |
| "timestamp": 1730982899.0, |
| "tz": { |
| "__classname__": "builtins.tuple", |
| "__version__": 1, |
| "__data__": ["UTC", "pendulum.tz.timezone.Timezone", 1, True], |
| }, |
| }, |
| }, |
| "end_from_trigger": False, |
| }, |
| trigger_timeout=None, |
| next_kwargs={}, |
| ), |
| ) |
| |
        # We are asserting the log messages here to ensure the task ran successfully
        # and mainly to check that the final state of the task matches the one in the DB.
| assert { |
| "exit_code": 0, |
| "duration": 0.0, |
| "final_state": "deferred", |
| "event": "Task finished", |
| "timestamp": mocker.ANY, |
| "level": "info", |
| "logger": "supervisor", |
| "loc": mocker.ANY, |
| "task_instance_id": str(ti.id), |
| } in captured_logs |
| |
| def test_supervisor_handles_already_running_task(self): |
| """Test that Supervisor prevents starting a Task Instance that is already running.""" |
| ti = TaskInstance( |
| id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() |
| ) |
| |
| # Mock API Server response indicating the TI is already running |
| # The API Server would return a 409 Conflict status code if the TI is not |
| # in a "queued" state. |
| def handle_request(request: httpx.Request) -> httpx.Response: |
| if request.url.path == f"/task-instances/{ti.id}/run": |
| return httpx.Response( |
| 409, |
| json={ |
| "reason": "invalid_state", |
| "message": "TI was not in a state where it could be marked as running", |
| "previous_state": "running", |
| }, |
| ) |
| |
| return httpx.Response(status_code=204) |
| |
| client = make_client(transport=httpx.MockTransport(handle_request)) |
| |
| with pytest.raises(ServerResponseError, match="Server returned error") as err: |
| ActivitySubprocess.start(dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=ti, client=client) |
| |
| assert err.value.response.status_code == 409 |
| assert err.value.detail == { |
| "reason": "invalid_state", |
| "message": "TI was not in a state where it could be marked as running", |
| "previous_state": "running", |
| } |
| |
| @pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True, ids=["log_level=error"]) |
| def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker, make_ti_context_dict): |
| """ |
        Test that the Supervisor does not cause the task to fail if the Task Instance is no longer
        in the running state. Instead, it logs the error and terminates the task process, since the task
        might have been externally set to a different state, have already completed, or be running on a different worker.
| |
| Also verifies that the supervisor does not try to send the finish request (update_state) to the API server. |
| """ |
| import airflow.sdk.execution_time.supervisor |
| |
| # Heartbeat every time around the loop |
| monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.0) |
| |
| def subprocess_main(): |
| CommsDecoder()._get_response() |
| sleep(5) |
| # Shouldn't get here |
| exit(5) |
| |
| ti_id = uuid7() |
| |
| # Track the number of requests to simulate mixed responses |
| request_count = {"count": 0} |
| |
| def handle_request(request: httpx.Request) -> httpx.Response: |
| if request.url.path == f"/task-instances/{ti_id}/heartbeat": |
| request_count["count"] += 1 |
| if request_count["count"] == 1: |
| # First request succeeds |
| return httpx.Response(status_code=204) |
| # Second request returns a conflict status code |
| return httpx.Response( |
| 409, |
| json={ |
| "reason": "not_running", |
| "message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate", |
| "current_state": "success", |
| }, |
| ) |
| if request.url.path == f"/task-instances/{ti_id}/run": |
| return httpx.Response(200, json=make_ti_context_dict()) |
| if request.url.path == f"/task-instances/{ti_id}/state": |
| pytest.fail("Should not have sent a state update request") |
| # Return a 204 for all other requests |
| return httpx.Response(status_code=204) |
| |
| proc = ActivitySubprocess.start( |
| dag_rel_path=os.devnull, |
| what=TaskInstance( |
| id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() |
| ), |
| client=make_client(transport=httpx.MockTransport(handle_request)), |
| target=subprocess_main, |
| bundle_info=FAKE_BUNDLE, |
| ) |
| |
| # Wait for the subprocess to finish -- it should have been terminated with SIGTERM |
| assert proc.wait() == -signal.SIGTERM |
| assert proc._exit_code == -signal.SIGTERM |
| assert proc.final_state == "SERVER_TERMINATED" |
| |
| assert request_count["count"] == 2 |
| # Verify the error was logged |
| assert captured_logs == [ |
| { |
| "detail": { |
| "reason": "not_running", |
| "message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate", |
| "current_state": "success", |
| }, |
| "event": "Server indicated the task shouldn't be running anymore", |
| "level": "error", |
| "status_code": 409, |
| "logger": "supervisor", |
| "timestamp": mocker.ANY, |
| "ti_id": ti_id, |
| "loc": mocker.ANY, |
| }, |
| { |
| "detail": { |
| "current_state": "success", |
| "message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate", |
| "reason": "not_running", |
| }, |
| "event": "Server indicated the task shouldn't be running anymore. Terminating process", |
| "level": "error", |
| "logger": "task", |
| "timestamp": mocker.ANY, |
| "loc": mocker.ANY, |
| }, |
| { |
| "event": "Task killed!", |
| "level": "error", |
| "logger": "task", |
| "timestamp": mocker.ANY, |
| "loc": mocker.ANY, |
| }, |
| ] |
| |
| @pytest.mark.parametrize("captured_logs", [logging.WARNING], indirect=True) |
| def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, time_machine): |
| """ |
        Test that ensures the WatchedSubprocess kills the process once
        MAX_FAILED_HEARTBEATS is reached.
| """ |
| max_failed_heartbeats = 3 |
| min_heartbeat_interval = 5 |
| monkeypatch.setattr( |
| "airflow.sdk.execution_time.supervisor.MAX_FAILED_HEARTBEATS", max_failed_heartbeats |
| ) |
| monkeypatch.setattr( |
| "airflow.sdk.execution_time.supervisor.MIN_HEARTBEAT_INTERVAL", min_heartbeat_interval |
| ) |
| |
| mock_process = mocker.Mock() |
| mock_process.pid = 12345 |
| |
| # Mock the client heartbeat method to raise an exception |
| mock_client_heartbeat = mocker.Mock(side_effect=Exception("Simulated heartbeat failure")) |
| client = mocker.Mock() |
| client.task_instances.heartbeat = mock_client_heartbeat |
| |
| # Patch the kill method at the class level so we can assert it was called with the correct signal |
| mock_kill = mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill") |
| |
| proc = ActivitySubprocess( |
| process_log=mocker.MagicMock(), |
| id=TI_ID, |
| pid=mock_process.pid, |
| stdin=mocker.MagicMock(), |
| client=client, |
| process=mock_process, |
| ) |
| current = min_heartbeat_interval |
| |
| def mock_monotonic(): |
| return current |
| |
| with patch( |
| "airflow.sdk.execution_time.supervisor.time.monotonic", |
| side_effect=mock_monotonic, |
| ): |
| time_now = timezone.datetime(2024, 11, 28, 12, 0, 0) |
| time_machine.move_to(time_now, tick=False) |
| |
| # Simulate sending heartbeats and ensure the process gets killed after max retries |
| for i in range(1, max_failed_heartbeats): |
| proc._send_heartbeat_if_needed() |
| assert proc.failed_heartbeats == i # Increment happens after failure |
| mock_client_heartbeat.assert_called_with(TI_ID, pid=mock_process.pid) |
| |
| # Ensure the retry log is present |
| expected_log = { |
| "event": "Failed to send heartbeat. Will be retried", |
| "failed_heartbeats": i, |
| "ti_id": TI_ID, |
| "max_retries": max_failed_heartbeats, |
| "level": "warning", |
| "logger": "supervisor", |
| "timestamp": mocker.ANY, |
| "exc_info": mocker.ANY, |
| "loc": mocker.ANY, |
| } |
| |
| assert expected_log in captured_logs |
| |
| # Advance time by `min_heartbeat_interval` to allow the next heartbeat |
| current += min_heartbeat_interval |
| |
| # On the final failure, the process should be killed |
| proc._send_heartbeat_if_needed() |
| |
| assert proc.failed_heartbeats == max_failed_heartbeats |
| mock_kill.assert_called_once_with(signal.SIGTERM, force=True) |
| mock_client_heartbeat.assert_called_with(TI_ID, pid=mock_process.pid) |
| assert { |
| "event": "Too many failed heartbeats; terminating process", |
| "level": "error", |
| "failed_heartbeats": max_failed_heartbeats, |
| "logger": "supervisor", |
| "timestamp": mocker.ANY, |
| "loc": mocker.ANY, |
| } in captured_logs |
| |
| @pytest.mark.parametrize( |
| ("terminal_state", "task_end_time_monotonic", "overtime_threshold", "expected_kill"), |
| [ |
| pytest.param( |
| None, |
| 15.0, |
| 10, |
| False, |
| id="no_terminal_state", |
| ), |
| pytest.param(TaskInstanceState.SUCCESS, 15.0, 10, False, id="below_threshold"), |
| pytest.param(TaskInstanceState.SUCCESS, 9.0, 10, True, id="above_threshold"), |
| pytest.param(TaskInstanceState.FAILED, 9.0, 10, True, id="above_threshold_failed_state"), |
| pytest.param(TaskInstanceState.SKIPPED, 9.0, 10, True, id="above_threshold_skipped_state"), |
| pytest.param(TaskInstanceState.SUCCESS, None, 20, False, id="task_end_datetime_none"), |
| ], |
| ) |
| def test_overtime_handling( |
| self, |
| mocker, |
| terminal_state, |
| task_end_time_monotonic, |
| overtime_threshold, |
| expected_kill, |
| monkeypatch, |
| ): |
| """Test handling of overtime under various conditions.""" |
        # Mocking the logger since we are only interested in whether it is called with the expected
        # message, not in the actual log output
| mock_logger = mocker.patch("airflow.sdk.execution_time.supervisor.log") |
| |
| # Mock the kill method at the class level so we can assert it was called with the correct signal |
| mock_kill = mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill") |
| |
| # Mock the current monotonic time |
| mocker.patch("time.monotonic", return_value=20.0) |
| |
| # Patch the task overtime threshold |
| monkeypatch.setattr( |
| "airflow.sdk.execution_time.supervisor.TASK_OVERTIME_THRESHOLD", overtime_threshold |
| ) |
| |
| mock_watched_subprocess = ActivitySubprocess( |
| process_log=mocker.MagicMock(), |
| id=TI_ID, |
| pid=12345, |
| stdin=mocker.Mock(), |
| process=mocker.Mock(), |
| client=mocker.Mock(), |
| ) |
| |
| # Set the terminal state and task end datetime |
| mock_watched_subprocess._terminal_state = terminal_state |
| mock_watched_subprocess._task_end_time_monotonic = task_end_time_monotonic |
| |
        # Call `_handle_process_overtime_if_needed` directly to trigger the overtime handling
        # This will call the `kill` method if the task has been running for too long
| mock_watched_subprocess._handle_process_overtime_if_needed() |
| |
| # Validate process kill behavior and log messages |
| if expected_kill: |
| mock_kill.assert_called_once_with(signal.SIGTERM, force=True) |
| mock_logger.warning.assert_called_once_with( |
| "Task success overtime reached; terminating process. " |
| "Modify `task_success_overtime` setting in [core] section of " |
| "Airflow configuration to change this limit.", |
| ti_id=TI_ID, |
| ) |
| else: |
| mock_kill.assert_not_called() |
| mock_logger.warning.assert_not_called() |
| |
| @pytest.mark.parametrize( |
| ("signal_to_raise", "log_pattern", "level"), |
| ( |
| pytest.param( |
| signal.SIGKILL, |
| re.compile(r"Process terminated by signal. Likely out of memory error"), |
| "critical", |
| id="kill", |
| ), |
| pytest.param( |
| signal.SIGTERM, |
| re.compile(r"Process terminated by signal. For more information"), |
| "error", |
| id="term", |
| ), |
| pytest.param( |
| signal.SIGSEGV, |
| re.compile(r".*SIGSEGV \(Segmentation Violation\) signal indicates", re.DOTALL), |
| "critical", |
| id="segv", |
| ), |
| ), |
| ) |
| def test_exit_by_signal(self, signal_to_raise, log_pattern, level, cap_structlog, client_with_ti_start): |
| def subprocess_main(): |
| import faulthandler |
| import os |
| |
| # Disable pytest fault handler |
| if faulthandler.is_enabled(): |
| faulthandler.disable() |
| |
| # Ensure we follow the "protocol" and get the startup message before we do anything |
| CommsDecoder()._get_response() |
| |
| os.kill(os.getpid(), signal_to_raise) |
| |
| proc = ActivitySubprocess.start( |
| dag_rel_path=os.devnull, |
| bundle_info=FAKE_BUNDLE, |
| what=TaskInstance( |
| id="4d828a62-a417-4936-a7a6-2b3fabacecab", |
| task_id="b", |
| dag_id="c", |
| run_id="d", |
| try_number=1, |
| dag_version_id=uuid7(), |
| ), |
| client=client_with_ti_start, |
| target=subprocess_main, |
| ) |
| |
| rc = proc.wait() |
| |
| assert { |
| "log_level": level, |
| "event": log_pattern, |
| } in cap_structlog |
| assert rc == -signal_to_raise |
| |
| @pytest.mark.execution_timeout(3) |
| def test_cleanup_sockets_after_delay(self, monkeypatch, mocker, time_machine): |
| """Supervisor should close sockets if EOF events are missed.""" |
| |
| monkeypatch.setattr("airflow.sdk.execution_time.supervisor.SOCKET_CLEANUP_TIMEOUT", 1.0) |
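        # With a 1s cleanup timeout, the 2s time shift below should make the supervisor treat
        # the still-open socket as leaked and close it even though no EOF event arrived.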
| |
| mock_process = mocker.Mock(pid=12345) |
| |
| time_machine.move_to(time.time(), tick=False) |
| |
| proc = ActivitySubprocess( |
| process_log=mocker.MagicMock(), |
| id=TI_ID, |
| pid=mock_process.pid, |
| stdin=mocker.MagicMock(), |
| client=mocker.MagicMock(), |
| process=mock_process, |
| ) |
| |
| proc.selector = mocker.MagicMock() |
| proc.selector.select.return_value = [] |
| |
| proc._exit_code = 0 |
| # Create a fake placeholder in the open socket weakref |
| proc._open_sockets[mocker.MagicMock()] = "test placeholder" |
| proc._process_exit_monotonic = time.time() |
| |
| mocker.patch.object( |
| ActivitySubprocess, |
| "_cleanup_open_sockets", |
| side_effect=lambda: setattr(proc, "_open_sockets", {}), |
| ) |
| |
| time_machine.shift(2) |
| |
| proc._monitor_subprocess() |
| assert len(proc._open_sockets) == 0 |
| |
| |
| class TestWatchedSubprocessKill: |
| @pytest.fixture |
| def mock_process(self, mocker): |
| process = mocker.Mock(spec=psutil.Process) |
| process.pid = 12345 |
| return process |
| |
| @pytest.fixture |
| def watched_subprocess(self, mocker, mock_process): |
| proc = ActivitySubprocess( |
| process_log=mocker.MagicMock(), |
| id=TI_ID, |
| pid=12345, |
| stdin=mocker.Mock(), |
| client=mocker.Mock(), |
| process=mock_process, |
| ) |
| # Mock the selector |
| mock_selector = mocker.Mock(spec=selectors.DefaultSelector) |
| mock_selector.select.return_value = [] |
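        # select() returning an empty list simulates "no pipe activity", letting tests drive
        # kill/exit handling without real sockets.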
| |
| # Set the selector on the process |
| proc.selector = mock_selector |
| return proc |
| |
| def test_kill_process_already_exited(self, watched_subprocess, mock_process): |
| """Test behavior when the process has already exited.""" |
| mock_process.wait.side_effect = psutil.NoSuchProcess(pid=1234) |
| watched_subprocess.kill(signal.SIGINT, force=True) |
| |
| mock_process.send_signal.assert_called_once_with(signal.SIGINT) |
| mock_process.wait.assert_called_once() |
| assert watched_subprocess._exit_code == -1 |
| |
| def test_kill_process_custom_signal(self, watched_subprocess, mock_process): |
| """Test that the process is killed with the correct signal.""" |
| mock_process.wait.return_value = 0 |
| |
| signal_to_send = signal.SIGUSR1 |
| watched_subprocess.kill(signal_to_send, force=False) |
| |
| mock_process.send_signal.assert_called_once_with(signal_to_send) |
| mock_process.wait.assert_called_once_with(timeout=0) |
| |
| @pytest.mark.parametrize( |
| ("signal_to_send", "exit_after"), |
| [ |
| pytest.param( |
| signal.SIGINT, |
| signal.SIGINT, |
| id="SIGINT-success-without-escalation", |
| ), |
| pytest.param( |
| signal.SIGINT, |
| signal.SIGTERM, |
| id="SIGINT-escalates-to-SIGTERM", |
| ), |
| pytest.param( |
| signal.SIGINT, |
| None, |
| id="SIGINT-escalates-to-SIGTERM-then-SIGKILL", |
| ), |
| pytest.param( |
| signal.SIGTERM, |
| None, |
| id="SIGTERM-escalates-to-SIGKILL", |
| ), |
| pytest.param( |
| signal.SIGKILL, |
| None, |
| id="SIGKILL-success-without-escalation", |
| ), |
| ], |
| ) |
| def test_kill_escalation_path(self, signal_to_send, exit_after, captured_logs, client_with_ti_start): |
| def subprocess_main(): |
| import signal |
| |
| def _handler(sig, frame): |
| print(f"Signal {sig} received", file=sys.stderr) |
| if exit_after == sig: |
| sleep(0.1) |
                # We exit 0 as that's what task_runner.py tries hard to do. The only difference if we
                # exit with non-zero is extra logs.
| exit(0) |
| sleep(5) |
| print("Should not get here") |
| |
| signal.signal(signal.SIGINT, _handler) |
| signal.signal(signal.SIGTERM, _handler) |
| try: |
| CommsDecoder()._get_response() |
| print("Ready") |
| sleep(10) |
| except Exception as e: |
| print(e) |
| # Shouldn't get here |
| exit(5) |
| |
| ti_id = uuid7() |
| |
| proc = ActivitySubprocess.start( |
| dag_rel_path=os.devnull, |
| bundle_info=FAKE_BUNDLE, |
| what=TaskInstance( |
| id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() |
| ), |
| client=client_with_ti_start, |
| target=subprocess_main, |
| ) |
| |
        # Ensure we get one normal run, to give the proc time to register its custom signal handler
| time.sleep(0.1) |
| proc._service_subprocess(max_wait_time=1) |
| proc.kill(signal_to_send=signal_to_send, escalation_delay=0.5, force=True) |
| |
        # Wait for the subprocess to finish. A handled signal makes the task exit 0;
        # otherwise escalation ends in SIGKILL, which surfaces as a negative return code.
        assert proc.wait() == (0 if exit_after else -signal.SIGKILL)
        exit_after = exit_after or signal.SIGKILL
| |
| logs = [{"event": m["event"], "logger": m["logger"]} for m in captured_logs] |
| expected_logs = [ |
| {"logger": "task.stdout", "event": "Ready"}, |
| ] |
| # Work out what logs we expect to see |
| if signal_to_send == signal.SIGINT: |
| expected_logs.append({"logger": "task.stderr", "event": "Signal 2 received"}) |
| if signal_to_send == signal.SIGTERM or ( |
| signal_to_send == signal.SIGINT and exit_after != signal.SIGINT |
| ): |
| if signal_to_send == signal.SIGINT: |
| expected_logs.append( |
| { |
| "event": "Process did not terminate in time; escalating", |
| "logger": "supervisor", |
| } |
| ) |
| expected_logs.append({"logger": "task.stderr", "event": "Signal 15 received"}) |
| if exit_after == signal.SIGKILL: |
| if signal_to_send in {signal.SIGINT, signal.SIGTERM}: |
| expected_logs.append( |
| { |
| "event": "Process did not terminate in time; escalating", |
| "logger": "supervisor", |
| } |
| ) |
| |
| expected_logs.extend(({"event": "Process exited", "logger": "supervisor"},)) |
| assert logs == expected_logs |
| |
| def test_service_subprocess(self, watched_subprocess, mock_process, mocker): |
| """Test `_service_subprocess` processes selector events and handles subprocess exit.""" |
| # Given |
| |
| # Mock file objects and handlers |
| mock_stdout = mocker.Mock() |
| mock_stderr = mocker.Mock() |
| |
| # Handlers for stdout and stderr |
| mock_stdout_handler = mocker.Mock(return_value=False) # Simulate EOF for stdout |
| mock_stderr_handler = mocker.Mock(return_value=True) # Continue processing for stderr |
| |
| mock_on_close = mocker.Mock() |
| |
| # Mock selector to return events |
| mock_key_stdout = mocker.Mock(fileobj=mock_stdout, data=(mock_stdout_handler, mock_on_close)) |
| mock_key_stderr = mocker.Mock(fileobj=mock_stderr, data=(mock_stderr_handler, mock_on_close)) |
| watched_subprocess.selector.select.return_value = [(mock_key_stdout, None), (mock_key_stderr, None)] |
| |
| # Mock to simulate process exited successfully |
| mock_process.wait.return_value = 0 |
| |
| # Our actual test |
| watched_subprocess._service_subprocess(max_wait_time=1.0) |
| |
| # Validations! |
| # Validate selector interactions |
| watched_subprocess.selector.select.assert_called_once_with(timeout=1.0) |
| |
| # Validate handler calls |
| mock_stdout_handler.assert_called_once_with(mock_stdout) |
| mock_stderr_handler.assert_called_once_with(mock_stderr) |
| |
| # Validate unregistering and closing of EOF file object |
| mock_on_close.assert_called_once_with(mock_stdout) |
| |
| # Validate that `_check_subprocess_exit` is called |
| mock_process.wait.assert_called_once_with(timeout=0) |
| |
| def test_max_wait_time_prevents_cpu_spike(self, watched_subprocess, mock_process, monkeypatch): |
| """Test that max_wait_time calculation prevents CPU spike when heartbeat timeout is reached.""" |
| # Mock the configuration to reproduce the CPU spike scenario |
| # Set heartbeat timeout to be very small relative to MIN_HEARTBEAT_INTERVAL |
| monkeypatch.setattr("airflow.sdk.execution_time.supervisor.HEARTBEAT_TIMEOUT", 1) |
| monkeypatch.setattr("airflow.sdk.execution_time.supervisor.MIN_HEARTBEAT_INTERVAL", 10) |
| |
| # Set up a scenario where the last successful heartbeat was a long time ago |
| # This will cause the heartbeat calculation to result in a negative value |
        watched_subprocess._last_successful_heartbeat = time.time() - 100  # 100 seconds ago
| |
| # Mock process to still be alive (not exited) |
| mock_process.wait.side_effect = psutil.TimeoutExpired(pid=12345, seconds=0) |
| |
| # Call _service_subprocess which is used in _monitor_subprocess |
| # This tests the max_wait_time calculation directly |
| watched_subprocess._service_subprocess(max_wait_time=0.005) # Very small timeout to verify our fix |
| |
        # Verify that selector.select was called with a minimum timeout of 0.01.
        # This proves our fix prevents the timeout=0 scenario that causes a CPU spike.
| watched_subprocess.selector.select.assert_called_once() |
| call_args = watched_subprocess.selector.select.call_args |
| timeout_arg = call_args[1]["timeout"] if "timeout" in call_args[1] else call_args[0][0] |
| |
| # The timeout should be at least 0.01 (our minimum), never 0 |
| assert timeout_arg >= 0.01, f"Expected timeout >= 0.01, got {timeout_arg}" |
| |
| @pytest.mark.parametrize( |
| ("heartbeat_timeout", "min_interval", "heartbeat_ago", "expected_min_timeout"), |
| [ |
| # Normal case: heartbeat is recent, should use calculated value |
| pytest.param(30, 5, 5, 0.01, id="normal_heartbeat"), |
| # Edge case: heartbeat timeout exceeded, should use minimum |
| pytest.param(10, 20, 50, 0.01, id="heartbeat_timeout_exceeded"), |
| # Bug reproduction case: timeout < interval, heartbeat very old |
| pytest.param(5, 10, 100, 0.01, id="cpu_spike_scenario"), |
| ], |
| ) |
| def test_max_wait_time_calculation_edge_cases( |
| self, |
| watched_subprocess, |
| mock_process, |
| monkeypatch, |
| heartbeat_timeout, |
| min_interval, |
| heartbeat_ago, |
| expected_min_timeout, |
| ): |
| """Test max_wait_time calculation in various edge case scenarios.""" |
| monkeypatch.setattr("airflow.sdk.execution_time.supervisor.HEARTBEAT_TIMEOUT", heartbeat_timeout) |
| monkeypatch.setattr("airflow.sdk.execution_time.supervisor.MIN_HEARTBEAT_INTERVAL", min_interval) |
| |
| watched_subprocess._last_successful_heartbeat = time.time() - heartbeat_ago |
| mock_process.wait.side_effect = psutil.TimeoutExpired(pid=12345, seconds=0) |
| |
| # Call the method and verify timeout is never less than our minimum |
| watched_subprocess._service_subprocess( |
| max_wait_time=999 |
| ) # Large value, should be overridden by calculation |
| |
| # Extract the timeout that was actually used |
| watched_subprocess.selector.select.assert_called_once() |
| call_args = watched_subprocess.selector.select.call_args |
| actual_timeout = call_args[1]["timeout"] if "timeout" in call_args[1] else call_args[0][0] |
| |
| assert actual_timeout >= expected_min_timeout |
| |
| |
| @dataclass |
| class ClientMock: |
| """Configuration for mocking client method calls.""" |
| |
| method_path: str |
| """Path to the client method to mock (e.g., 'connections.get', 'variables.set').""" |
| |
| args: tuple = field(default_factory=tuple) |
| """Positional arguments the client method should be called with.""" |
| |
| kwargs: dict = field(default_factory=dict) |
| """Keyword arguments the client method should be called with.""" |
| |
| response: Any = None |
| """What the mocked client method should return when called.""" |
| |
| |
| @dataclass |
| class RequestTestCase: |
| """Test case data for request handling tests in `TestHandleRequest` class.""" |
| |
| message: Any |
| """The request message to send to the supervisor (e.g., GetConnection, SetXCom).""" |
| |
| test_id: str |
| """Unique identifier for this test case, used in pytest parameterization.""" |
| |
| client_mock: ClientMock | None = None |
| """Client method mocking configuration. None for messages that don't require client calls.""" |
| |
| expected_body: dict | None = None |
| """Expected response body from supervisor. None if no response body expected.""" |
| |
| mask_secret_args: tuple | None = None |
| """Arguments that should be passed to the secret masker for redaction.""" |
| |
| |
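# A minimal illustrative case showing how the pieces fit together; the harness
# presumably resolves `method_path` on the mocked client with operator.attrgetter
# (imported above). This constant is an example only and is not collected below.
_EXAMPLE_CASE = RequestTestCase(
    message=GetVariable(key="example_key"),
    test_id="example",
    client_mock=ClientMock(
        method_path="variables.get",
        args=("example_key",),
        response=VariableResult(key="example_key", value="example_value"),
    ),
    expected_body={"key": "example_key", "value": "example_value", "type": "VariableResult"},
)
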
| # Test cases for request handling |
| REQUEST_TEST_CASES = [ |
| RequestTestCase( |
| message=GetConnection(conn_id="test_conn"), |
| test_id="get_connection", |
| client_mock=ClientMock( |
| method_path="connections.get", |
| args=("test_conn",), |
| response=ConnectionResult(conn_id="test_conn", conn_type="mysql"), |
| ), |
| expected_body={"conn_id": "test_conn", "conn_type": "mysql", "type": "ConnectionResult"}, |
| ), |
| RequestTestCase( |
| message=GetConnection(conn_id="test_conn"), |
| test_id="get_connection_with_password", |
| client_mock=ClientMock( |
| method_path="connections.get", |
| args=("test_conn",), |
| response=ConnectionResult(conn_id="test_conn", conn_type="mysql", password="password"), |
| ), |
| expected_body={ |
| "conn_id": "test_conn", |
| "conn_type": "mysql", |
| "password": "password", |
| "type": "ConnectionResult", |
| }, |
| mask_secret_args=("password",), |
| ), |
| RequestTestCase( |
| message=GetConnection(conn_id="test_conn"), |
| test_id="get_connection_with_alias", |
| client_mock=ClientMock( |
| method_path="connections.get", |
| args=("test_conn",), |
| response=ConnectionResult(conn_id="test_conn", conn_type="mysql", schema="mysql"), # type: ignore[call-arg] |
| ), |
| expected_body={ |
| "conn_id": "test_conn", |
| "conn_type": "mysql", |
| "schema": "mysql", |
| "type": "ConnectionResult", |
| }, |
| ), |
| RequestTestCase( |
| message=GetVariable(key="test_key"), |
| test_id="get_variable", |
| client_mock=ClientMock( |
| method_path="variables.get", |
| args=("test_key",), |
| response=VariableResult(key="test_key", value="test_value"), |
| ), |
| expected_body={"key": "test_key", "value": "test_value", "type": "VariableResult"}, |
| mask_secret_args=("test_value", "test_key"), |
| ), |
| RequestTestCase( |
| message=PutVariable(key="test_key", value="test_value", description="test_description"), |
| test_id="set_variable", |
| client_mock=ClientMock( |
| method_path="variables.set", |
| args=("test_key", "test_value", "test_description"), |
| response=OKResponse(ok=True), |
| ), |
| ), |
| RequestTestCase( |
| message=DeleteVariable(key="test_key"), |
| test_id="delete_variable", |
| client_mock=ClientMock( |
| method_path="variables.delete", |
| args=("test_key",), |
| response=OKResponse(ok=True), |
| ), |
| expected_body={"ok": True, "type": "OKResponse"}, |
| ), |
| RequestTestCase( |
| message=DeferTask(next_method="execute_callback", classpath="my-classpath"), |
| test_id="patch_task_instance_to_deferred", |
| client_mock=ClientMock( |
| method_path="task_instances.defer", |
| args=(TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), |
| ), |
| ), |
| RequestTestCase( |
| message=RescheduleTask( |
| reschedule_date=timezone.parse("2024-10-31T12:00:00Z"), |
| end_date=timezone.parse("2024-10-31T12:00:00Z"), |
| ), |
| test_id="patch_task_instance_to_up_for_reschedule", |
| client_mock=ClientMock( |
| method_path="task_instances.reschedule", |
| args=( |
| TI_ID, |
| RescheduleTask( |
| reschedule_date=timezone.parse("2024-10-31T12:00:00Z"), |
| end_date=timezone.parse("2024-10-31T12:00:00Z"), |
| ), |
| ), |
| ), |
| ), |
| RequestTestCase( |
| message=GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), |
| test_id="get_xcom", |
| client_mock=ClientMock( |
| method_path="xcoms.get", |
| args=("test_dag", "test_run", "test_task", "test_key", None, False), |
| response=XComResult(key="test_key", value="test_value"), |
| ), |
| expected_body={"key": "test_key", "value": "test_value", "type": "XComResult"}, |
| ), |
| RequestTestCase( |
| message=GetXCom( |
| dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key", map_index=2 |
| ), |
| test_id="get_xcom_map_index", |
| client_mock=ClientMock( |
| method_path="xcoms.get", |
| args=("test_dag", "test_run", "test_task", "test_key", 2, False), |
| response=XComResult(key="test_key", value="test_value"), |
| ), |
| expected_body={"key": "test_key", "value": "test_value", "type": "XComResult"}, |
| ), |
| RequestTestCase( |
| message=GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), |
| test_id="get_xcom_not_found", |
| client_mock=ClientMock( |
| method_path="xcoms.get", |
| args=("test_dag", "test_run", "test_task", "test_key", None, False), |
| response=XComResult(key="test_key", value=None, type="XComResult"), |
| ), |
| expected_body={"key": "test_key", "value": None, "type": "XComResult"}, |
| ), |
| RequestTestCase( |
| message=GetXCom( |
| dag_id="test_dag", |
| run_id="test_run", |
| task_id="test_task", |
| key="test_key", |
| include_prior_dates=True, |
| ), |
| test_id="get_xcom_include_prior_dates", |
| client_mock=ClientMock( |
| method_path="xcoms.get", |
| args=("test_dag", "test_run", "test_task", "test_key", None, True), |
| response=XComResult(key="test_key", value=None, type="XComResult"), |
| ), |
| expected_body={"key": "test_key", "value": None, "type": "XComResult"}, |
| ), |
| RequestTestCase( |
| message=SetXCom( |
| dag_id="test_dag", |
| run_id="test_run", |
| task_id="test_task", |
| key="test_key", |
| value='{"key": "test_key", "value": {"key2": "value2"}}', |
| ), |
| client_mock=ClientMock( |
| method_path="xcoms.set", |
| args=( |
| "test_dag", |
| "test_run", |
| "test_task", |
| "test_key", |
| '{"key": "test_key", "value": {"key2": "value2"}}', |
| None, |
| None, |
| ), |
| response=OKResponse(ok=True), |
| ), |
| test_id="set_xcom", |
| ), |
| RequestTestCase( |
| message=SetXCom( |
| dag_id="test_dag", |
| run_id="test_run", |
| task_id="test_task", |
| key="test_key", |
| value='{"key": "test_key", "value": {"key2": "value2"}}', |
| map_index=2, |
| ), |
| client_mock=ClientMock( |
| method_path="xcoms.set", |
| args=( |
| "test_dag", |
| "test_run", |
| "test_task", |
| "test_key", |
| '{"key": "test_key", "value": {"key2": "value2"}}', |
| 2, |
| None, |
| ), |
| response=OKResponse(ok=True), |
| ), |
| test_id="set_xcom_with_map_index", |
| ), |
| RequestTestCase( |
| message=SetXCom( |
| dag_id="test_dag", |
| run_id="test_run", |
| task_id="test_task", |
| key="test_key", |
| value='{"key": "test_key", "value": {"key2": "value2"}}', |
| map_index=2, |
| mapped_length=3, |
| ), |
| client_mock=ClientMock( |
| method_path="xcoms.set", |
| args=( |
| "test_dag", |
| "test_run", |
| "test_task", |
| "test_key", |
| '{"key": "test_key", "value": {"key2": "value2"}}', |
| 2, |
| 3, |
| ), |
| response=OKResponse(ok=True), |
| ), |
| test_id="set_xcom_with_map_index_and_mapped_length", |
| ), |
| RequestTestCase( |
| message=DeleteXCom( |
| dag_id="test_dag", |
| run_id="test_run", |
| task_id="test_task", |
| key="test_key", |
| map_index=2, |
| ), |
| client_mock=ClientMock( |
| method_path="xcoms.delete", |
| args=("test_dag", "test_run", "test_task", "test_key", 2), |
| response=OKResponse(ok=True), |
| ), |
| test_id="delete_xcom", |
| ), |
| RequestTestCase( |
| message=RetryTask( |
| end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test retry task" |
| ), |
| client_mock=ClientMock( |
| method_path="task_instances.retry", |
| kwargs={ |
| "id": TI_ID, |
| "end_date": timezone.parse("2024-10-31T12:00:00Z"), |
| "rendered_map_index": "test retry task", |
| }, |
| response=OKResponse(ok=True), |
| ), |
| test_id="up_for_retry", |
| ), |
| RequestTestCase( |
| message=SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}), |
| client_mock=ClientMock( |
| method_path="task_instances.set_rtif", |
| args=(TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}), |
| response=OKResponse(ok=True), |
| ), |
| test_id="set_rtif", |
| ), |
| RequestTestCase( |
| message=SetRenderedMapIndex(rendered_map_index="Label: task_1"), |
| client_mock=ClientMock( |
| method_path="task_instances.set_rendered_map_index", |
| args=(TI_ID, "Label: task_1"), |
| response=OKResponse(ok=True), |
| ), |
| test_id="set_rendered_map_index", |
| ), |
| RequestTestCase( |
| message=SucceedTask( |
| end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test success task" |
| ), |
| client_mock=ClientMock( |
| method_path="task_instances.succeed", |
| kwargs={ |
| "id": TI_ID, |
| "outlet_events": None, |
| "task_outlets": None, |
| "when": timezone.parse("2024-10-31T12:00:00Z"), |
| "rendered_map_index": "test success task", |
| }, |
| ), |
| test_id="succeed_task", |
| ), |
| RequestTestCase( |
| message=GetAssetByName(name="asset"), |
| expected_body={"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, |
| client_mock=ClientMock( |
| method_path="assets.get", |
| kwargs={"name": "asset"}, |
| response=AssetResult(name="asset", uri="s3://bucket/obj", group="asset"), |
| ), |
| test_id="get_asset_by_name", |
| ), |
| RequestTestCase( |
| message=GetAssetByUri(uri="s3://bucket/obj"), |
| expected_body={"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, |
| client_mock=ClientMock( |
| method_path="assets.get", |
| kwargs={"uri": "s3://bucket/obj"}, |
| response=AssetResult(name="asset", uri="s3://bucket/obj", group="asset"), |
| ), |
| test_id="get_asset_by_uri", |
| ), |
| RequestTestCase( |
| message=GetAssetEventByAsset(uri="s3://bucket/obj", name="test"), |
| expected_body={ |
| "asset_events": [ |
| { |
| "id": 1, |
| "timestamp": timezone.parse("2024-10-31T12:00:00Z"), |
| "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, |
| "created_dagruns": [], |
| } |
| ], |
| "type": "AssetEventsResult", |
| }, |
| client_mock=ClientMock( |
| method_path="asset_events.get", |
| kwargs={ |
| "uri": "s3://bucket/obj", |
| "name": "test", |
| "after": None, |
| "before": None, |
| "limit": None, |
| "ascending": True, |
| }, |
| response=AssetEventsResult( |
| asset_events=[ |
| AssetEventResponse( |
| id=1, |
| asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), |
| created_dagruns=[], |
| timestamp=timezone.parse("2024-10-31T12:00:00Z"), |
| ), |
| ], |
| ), |
| ), |
| test_id="get_asset_events_by_uri_and_name", |
| ), |
| RequestTestCase( |
| message=GetAssetEventByAsset( |
| uri="s3://bucket/obj", |
| name="test", |
| after=datetime(2024, 10, 1, 12, 0, 0, tzinfo=timezone.utc), |
| before=datetime(2024, 10, 15, 12, 0, 0, tzinfo=timezone.utc), |
| limit=5, |
| ascending=False, |
| ), |
| expected_body={ |
| "asset_events": [ |
| { |
| "id": 1, |
| "timestamp": timezone.parse("2024-10-31T12:00:00Z"), |
| "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, |
| "created_dagruns": [], |
| } |
| ], |
| "type": "AssetEventsResult", |
| }, |
| client_mock=ClientMock( |
| method_path="asset_events.get", |
| kwargs={ |
| "uri": "s3://bucket/obj", |
| "name": "test", |
| "after": timezone.parse("2024-10-01T12:00:00Z"), |
| "before": timezone.parse("2024-10-15T12:00:00Z"), |
| "limit": 5, |
| "ascending": False, |
| }, |
| response=AssetEventsResult( |
| asset_events=[ |
| AssetEventResponse( |
| id=1, |
| asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), |
| created_dagruns=[], |
| timestamp=timezone.parse("2024-10-31T12:00:00Z"), |
| ), |
| ], |
| ), |
| ), |
| test_id="get_asset_events_by_uri_and_name_with_filters", |
| ), |
| RequestTestCase( |
| message=GetAssetEventByAsset(uri="s3://bucket/obj", name=None), |
| expected_body={ |
| "asset_events": [ |
| { |
| "id": 1, |
| "timestamp": timezone.parse("2024-10-31T12:00:00Z"), |
| "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, |
| "created_dagruns": [], |
| } |
| ], |
| "type": "AssetEventsResult", |
| }, |
| client_mock=ClientMock( |
| method_path="asset_events.get", |
| kwargs={ |
| "uri": "s3://bucket/obj", |
| "name": None, |
| "after": None, |
| "before": None, |
| "limit": None, |
| "ascending": True, |
| }, |
| response=AssetEventsResult( |
| asset_events=[ |
| AssetEventResponse( |
| id=1, |
| asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), |
| created_dagruns=[], |
| timestamp=timezone.parse("2024-10-31T12:00:00Z"), |
| ) |
| ], |
| ), |
| ), |
| test_id="get_asset_events_by_uri", |
| ), |
| RequestTestCase( |
| message=GetAssetEventByAsset( |
| uri="s3://bucket/obj", |
| name=None, |
| after=datetime(2024, 10, 1, 12, 0, 0, tzinfo=timezone.utc), |
| before=datetime(2024, 10, 15, 12, 0, 0, tzinfo=timezone.utc), |
| limit=5, |
| ascending=False, |
| ), |
| expected_body={ |
| "asset_events": [ |
| { |
| "id": 1, |
| "timestamp": timezone.parse("2024-10-31T12:00:00Z"), |
| "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, |
| "created_dagruns": [], |
| } |
| ], |
| "type": "AssetEventsResult", |
| }, |
| client_mock=ClientMock( |
| method_path="asset_events.get", |
| kwargs={ |
| "uri": "s3://bucket/obj", |
| "name": None, |
| "after": timezone.parse("2024-10-01T12:00:00Z"), |
| "before": timezone.parse("2024-10-15T12:00:00Z"), |
| "limit": 5, |
| "ascending": False, |
| }, |
| response=AssetEventsResult( |
| asset_events=[ |
| AssetEventResponse( |
| id=1, |
| asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), |
| created_dagruns=[], |
| timestamp=timezone.parse("2024-10-31T12:00:00Z"), |
| ) |
| ], |
| ), |
| ), |
| test_id="get_asset_events_by_uri_with_filters", |
| ), |
| RequestTestCase( |
| message=GetAssetEventByAsset(uri=None, name="test"), |
| expected_body={ |
| "asset_events": [ |
| { |
| "id": 1, |
| "timestamp": timezone.parse("2024-10-31T12:00:00Z"), |
| "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, |
| "created_dagruns": [], |
| } |
| ], |
| "type": "AssetEventsResult", |
| }, |
| client_mock=ClientMock( |
| method_path="asset_events.get", |
| kwargs={ |
| "uri": None, |
| "name": "test", |
| "after": None, |
| "before": None, |
| "limit": None, |
| "ascending": True, |
| }, |
| response=AssetEventsResult( |
| asset_events=[ |
| AssetEventResponse( |
| id=1, |
| asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), |
| created_dagruns=[], |
| timestamp=timezone.parse("2024-10-31T12:00:00Z"), |
| ) |
| ] |
| ), |
| ), |
| test_id="get_asset_events_by_name", |
| ), |
| RequestTestCase( |
| message=GetAssetEventByAsset( |
| uri=None, |
| name="test", |
| after=datetime(2024, 10, 1, 12, 0, 0, tzinfo=timezone.utc), |
| before=datetime(2024, 10, 15, 12, 0, 0, tzinfo=timezone.utc), |
| limit=5, |
| ascending=False, |
| ), |
| expected_body={ |
| "asset_events": [ |
| { |
| "id": 1, |
| "timestamp": timezone.parse("2024-10-31T12:00:00Z"), |
| "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, |
| "created_dagruns": [], |
| } |
| ], |
| "type": "AssetEventsResult", |
| }, |
| client_mock=ClientMock( |
| method_path="asset_events.get", |
| kwargs={ |
| "uri": None, |
| "name": "test", |
| "after": timezone.parse("2024-10-01T12:00:00Z"), |
| "before": timezone.parse("2024-10-15T12:00:00Z"), |
| "limit": 5, |
| "ascending": False, |
| }, |
| response=AssetEventsResult( |
| asset_events=[ |
| AssetEventResponse( |
| id=1, |
| asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), |
| created_dagruns=[], |
| timestamp=timezone.parse("2024-10-31T12:00:00Z"), |
| ) |
| ] |
| ), |
| ), |
| test_id="get_asset_events_by_name_with_filters", |
| ), |
| RequestTestCase( |
| message=GetAssetEventByAssetAlias(alias_name="test_alias"), |
| expected_body={ |
| "asset_events": [ |
| { |
| "id": 1, |
| "timestamp": timezone.parse("2024-10-31T12:00:00Z"), |
| "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, |
| "created_dagruns": [], |
| } |
| ], |
| "type": "AssetEventsResult", |
| }, |
| client_mock=ClientMock( |
| method_path="asset_events.get", |
| kwargs={ |
| "alias_name": "test_alias", |
| "after": None, |
| "before": None, |
| "limit": None, |
| "ascending": True, |
| }, |
| response=AssetEventsResult( |
| asset_events=[ |
| AssetEventResponse( |
| id=1, |
| asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), |
| created_dagruns=[], |
| timestamp=timezone.parse("2024-10-31T12:00:00Z"), |
| ) |
| ] |
| ), |
| ), |
| test_id="get_asset_events_by_asset_alias", |
| ), |
| RequestTestCase( |
| message=GetAssetEventByAssetAlias( |
| alias_name="test_alias", |
| after=datetime(2024, 10, 1, 12, 0, 0, tzinfo=timezone.utc), |
| before=datetime(2024, 10, 15, 12, 0, 0, tzinfo=timezone.utc), |
| limit=5, |
| ascending=False, |
| ), |
| expected_body={ |
| "asset_events": [ |
| { |
| "id": 1, |
| "timestamp": timezone.parse("2024-10-31T12:00:00Z"), |
| "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, |
| "created_dagruns": [], |
| } |
| ], |
| "type": "AssetEventsResult", |
| }, |
| client_mock=ClientMock( |
| method_path="asset_events.get", |
| kwargs={ |
| "alias_name": "test_alias", |
| "after": timezone.parse("2024-10-01T12:00:00Z"), |
| "before": timezone.parse("2024-10-15T12:00:00Z"), |
| "limit": 5, |
| "ascending": False, |
| }, |
| response=AssetEventsResult( |
| asset_events=[ |
| AssetEventResponse( |
| id=1, |
| asset=AssetResponse(name="asset", uri="s3://bucket/obj", group="asset"), |
| created_dagruns=[], |
| timestamp=timezone.parse("2024-10-31T12:00:00Z"), |
| ) |
| ] |
| ), |
| ), |
| test_id="get_asset_events_by_asset_alias_with_filters", |
| ), |
| RequestTestCase( |
| message=ValidateInletsAndOutlets(ti_id=TI_ID), |
| expected_body={ |
| "inactive_assets": [{"name": "asset_name", "uri": "asset_uri", "type": "asset"}], |
| "type": "InactiveAssetsResult", |
| }, |
| client_mock=ClientMock( |
| method_path="task_instances.validate_inlets_and_outlets", |
| args=(TI_ID,), |
| response=InactiveAssetsResult( |
| inactive_assets=[AssetProfile(name="asset_name", uri="asset_uri", type="asset")] |
| ), |
| ), |
| test_id="validate_inlets_and_outlets", |
| ), |
| RequestTestCase( |
| message=GetPrevSuccessfulDagRun(ti_id=TI_ID), |
| expected_body={ |
| "data_interval_start": timezone.parse("2025-01-10T12:00:00Z"), |
| "data_interval_end": timezone.parse("2025-01-10T14:00:00Z"), |
| "start_date": timezone.parse("2025-01-10T12:00:00Z"), |
| "end_date": timezone.parse("2025-01-10T14:00:00Z"), |
| "type": "PrevSuccessfulDagRunResult", |
| }, |
| client_mock=ClientMock( |
| method_path="task_instances.get_previous_successful_dagrun", |
| args=(TI_ID,), |
| response=PrevSuccessfulDagRunResult( |
| start_date=timezone.parse("2025-01-10T12:00:00Z"), |
| end_date=timezone.parse("2025-01-10T14:00:00Z"), |
| data_interval_start=timezone.parse("2025-01-10T12:00:00Z"), |
| data_interval_end=timezone.parse("2025-01-10T14:00:00Z"), |
| ), |
| ), |
| test_id="get_prev_successful_dagrun", |
| ), |
| RequestTestCase( |
| message=TriggerDagRun( |
| dag_id="test_dag", |
| run_id="test_run", |
| conf={"key": "value"}, |
| logical_date=timezone.datetime(2025, 1, 1), |
| reset_dag_run=True, |
| ), |
| expected_body={"ok": True, "type": "OKResponse"}, |
| client_mock=ClientMock( |
| method_path="dag_runs.trigger", |
| args=("test_dag", "test_run", {"key": "value"}, timezone.datetime(2025, 1, 1), True, None), |
| response=OKResponse(ok=True), |
| ), |
| test_id="dag_run_trigger", |
| ), |
| RequestTestCase( |
| message=TriggerDagRun( |
| dag_id="test_dag", |
| run_id="test_run", |
| conf={"key": "value"}, |
| logical_date=timezone.datetime(2025, 1, 1), |
| reset_dag_run=True, |
| note="Test Note", |
| ), |
| expected_body={"ok": True, "type": "OKResponse"}, |
| client_mock=ClientMock( |
| method_path="dag_runs.trigger", |
| args=("test_dag", "test_run", {"key": "value"}, timezone.datetime(2025, 1, 1), True, "Test Note"), |
| response=OKResponse(ok=True), |
| ), |
| test_id="dag_run_trigger", |
| ), |
| RequestTestCase( |
| message=TriggerDagRun(dag_id="test_dag", run_id="test_run"), |
| expected_body={"error": "DAGRUN_ALREADY_EXISTS", "detail": None, "type": "ErrorResponse"}, |
| client_mock=ClientMock( |
| method_path="dag_runs.trigger", |
| args=("test_dag", "test_run", None, None, False, None), |
| response=ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS), |
| ), |
| test_id="dag_run_trigger_already_exists", |
| ), |
| RequestTestCase( |
| message=GetDagRun(dag_id="test_dag", run_id="test_run"), |
| expected_body={ |
| "dag_id": "test_dag", |
| "run_id": "prev_run", |
| "logical_date": timezone.parse("2024-01-14T12:00:00Z"), |
| "partition_key": None, |
| "run_type": "scheduled", |
| "start_date": timezone.parse("2024-01-15T12:00:00Z"), |
| "run_after": timezone.parse("2024-01-15T12:00:00Z"), |
| "consumed_asset_events": [], |
| "state": "success", |
| "data_interval_start": None, |
| "data_interval_end": None, |
| "end_date": None, |
| "clear_number": 0, |
| "conf": None, |
| "triggering_user_name": None, |
| "type": "DagRunResult", |
| "note": None, |
| }, |
| client_mock=ClientMock( |
| method_path="dag_runs.get_detail", |
| args=("test_dag", "test_run"), |
| response=DagRunResult( |
| dag_id="test_dag", |
| run_id="prev_run", |
| logical_date=timezone.parse("2024-01-14T12:00:00Z"), |
| run_type=DagRunType.SCHEDULED, |
| start_date=timezone.parse("2024-01-15T12:00:00Z"), |
| run_after=timezone.parse("2024-01-15T12:00:00Z"), |
| consumed_asset_events=[], |
| state=DagRunState.SUCCESS, |
| triggering_user_name=None, |
| ), |
| ), |
| test_id="get_dag_run", |
| ), |
| RequestTestCase( |
| message=GetDagRunState(dag_id="test_dag", run_id="test_run"), |
| expected_body={"state": "running", "type": "DagRunStateResult"}, |
| client_mock=ClientMock( |
| method_path="dag_runs.get_state", |
| args=("test_dag", "test_run"), |
| response=DagRunStateResult(state=DagRunState.RUNNING), |
| ), |
| test_id="get_dag_run_state", |
| ), |
| RequestTestCase( |
| message=GetPreviousDagRun( |
| dag_id="test_dag", |
| logical_date=timezone.parse("2024-01-15T12:00:00Z"), |
| ), |
| expected_body={ |
| "dag_run": { |
| "dag_id": "test_dag", |
| "run_id": "prev_run", |
| "logical_date": timezone.parse("2024-01-14T12:00:00Z"), |
| "partition_key": None, |
| "run_type": "scheduled", |
| "start_date": timezone.parse("2024-01-15T12:00:00Z"), |
| "run_after": timezone.parse("2024-01-15T12:00:00Z"), |
| "consumed_asset_events": [], |
| "state": "success", |
| "data_interval_start": None, |
| "data_interval_end": None, |
| "end_date": None, |
| "clear_number": 0, |
| "conf": None, |
| "triggering_user_name": None, |
| "note": None, |
| }, |
| "type": "PreviousDagRunResult", |
| }, |
| client_mock=ClientMock( |
| method_path="dag_runs.get_previous", |
| kwargs={ |
| "dag_id": "test_dag", |
| "logical_date": timezone.parse("2024-01-15T12:00:00Z"), |
| "state": None, |
| }, |
| response=PreviousDagRunResult( |
| dag_run=DagRun( |
| dag_id="test_dag", |
| run_id="prev_run", |
| logical_date=timezone.parse("2024-01-14T12:00:00Z"), |
| run_type=DagRunType.SCHEDULED, |
| start_date=timezone.parse("2024-01-15T12:00:00Z"), |
| run_after=timezone.parse("2024-01-15T12:00:00Z"), |
| consumed_asset_events=[], |
| state=DagRunState.SUCCESS, |
| triggering_user_name=None, |
| ) |
| ), |
| ), |
| test_id="get_previous_dagrun", |
| ), |
| RequestTestCase( |
| message=GetPreviousDagRun( |
| dag_id="test_dag", |
| logical_date=timezone.parse("2024-01-15T12:00:00Z"), |
| state="success", |
| ), |
| expected_body={ |
| "dag_run": None, |
| "type": "PreviousDagRunResult", |
| }, |
| client_mock=ClientMock( |
| method_path="dag_runs.get_previous", |
| kwargs={ |
| "dag_id": "test_dag", |
| "logical_date": timezone.parse("2024-01-15T12:00:00Z"), |
| "state": "success", |
| }, |
| response=PreviousDagRunResult(dag_run=None), |
| ), |
| test_id="get_previous_dagrun_with_state", |
| ), |
| RequestTestCase( |
| message=GetPreviousTI( |
| dag_id="test_dag", |
| task_id="test_task", |
| logical_date=timezone.parse("2024-01-15T12:00:00Z"), |
| map_index=0, |
| state=TaskInstanceState.SUCCESS, |
| ), |
| expected_body={ |
| "task_instance": { |
| "task_id": "test_task", |
| "dag_id": "test_dag", |
| "run_id": "prev_run", |
| "logical_date": timezone.parse("2024-01-14T12:00:00Z"), |
| "start_date": timezone.parse("2024-01-14T12:05:00Z"), |
| "end_date": timezone.parse("2024-01-14T12:10:00Z"), |
| "state": "success", |
| "try_number": 1, |
| "map_index": 0, |
| "duration": 300.0, |
| }, |
| "type": "PreviousTIResult", |
| }, |
| client_mock=ClientMock( |
| method_path="task_instances.get_previous", |
| kwargs={ |
| "dag_id": "test_dag", |
| "task_id": "test_task", |
| "logical_date": timezone.parse("2024-01-15T12:00:00Z"), |
| "map_index": 0, |
| "state": TaskInstanceState.SUCCESS, |
| }, |
| response=PreviousTIResult( |
| task_instance=PreviousTIResponse( |
| task_id="test_task", |
| dag_id="test_dag", |
| run_id="prev_run", |
| logical_date=timezone.parse("2024-01-14T12:00:00Z"), |
| start_date=timezone.parse("2024-01-14T12:05:00Z"), |
| end_date=timezone.parse("2024-01-14T12:10:00Z"), |
| state="success", |
| try_number=1, |
| map_index=0, |
| duration=300.0, |
| ) |
| ), |
| ), |
| test_id="get_previous_ti", |
| ), |
| RequestTestCase( |
| message=GetTaskRescheduleStartDate(ti_id=TI_ID), |
| expected_body={ |
| "start_date": timezone.parse("2024-10-31T12:00:00Z"), |
| "type": "TaskRescheduleStartDate", |
| }, |
| client_mock=ClientMock( |
| method_path="task_instances.get_reschedule_start_date", |
| args=(TI_ID, 1), |
| response=TaskRescheduleStartDate(start_date=timezone.parse("2024-10-31T12:00:00Z")), |
| ), |
| test_id="get_task_reschedule_start_date", |
| ), |
| RequestTestCase( |
| message=GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]), |
| expected_body={"count": 2, "type": "TICount"}, |
| client_mock=ClientMock( |
| method_path="task_instances.get_count", |
| kwargs={ |
| "dag_id": "test_dag", |
| "map_index": None, |
| "logical_dates": None, |
| "run_ids": None, |
| "states": None, |
| "task_group_id": None, |
| "task_ids": ["task1", "task2"], |
| }, |
| response=TICount(count=2), |
| ), |
| test_id="get_ti_count", |
| ), |
| RequestTestCase( |
| message=GetDRCount(dag_id="test_dag", states=["success", "failed"]), |
| expected_body={"count": 2, "type": "DRCount"}, |
| client_mock=ClientMock( |
| method_path="dag_runs.get_count", |
| kwargs={ |
| "dag_id": "test_dag", |
| "logical_dates": None, |
| "run_ids": None, |
| "states": ["success", "failed"], |
| }, |
| response=DRCount(count=2), |
| ), |
| test_id="get_dr_count", |
| ), |
| RequestTestCase( |
| message=GetTaskStates(dag_id="test_dag", task_group_id="test_group"), |
| expected_body={ |
| "task_states": {"run_id": {"task1": "success", "task2": "failed"}}, |
| "type": "TaskStatesResult", |
| }, |
| client_mock=ClientMock( |
| method_path="task_instances.get_task_states", |
| kwargs={ |
| "dag_id": "test_dag", |
| "map_index": None, |
| "task_ids": None, |
| "logical_dates": None, |
| "run_ids": None, |
| "task_group_id": "test_group", |
| }, |
| response=TaskStatesResult(task_states={"run_id": {"task1": "success", "task2": "failed"}}), |
| ), |
| test_id="get_task_states", |
| ), |
| RequestTestCase( |
| message=GetXComSequenceItem( |
| key="test_key", |
| dag_id="test_dag", |
| run_id="test_run", |
| task_id="test_task", |
| offset=0, |
| ), |
| expected_body={"root": "test_value", "type": "XComSequenceIndexResult"}, |
| client_mock=ClientMock( |
| method_path="xcoms.get_sequence_item", |
| args=("test_dag", "test_run", "test_task", "test_key", 0), |
| response=XComSequenceIndexResult(root="test_value"), |
| ), |
| test_id="get_xcom_seq_item", |
| ), |
| RequestTestCase( |
| message=GetXComSequenceItem( |
| key="test_key", |
| dag_id="test_dag", |
| run_id="test_run", |
| task_id="test_task", |
| offset=2, |
| ), |
| expected_body={"error": "XCOM_NOT_FOUND", "detail": None, "type": "ErrorResponse"}, |
| client_mock=ClientMock( |
| method_path="xcoms.get_sequence_item", |
| args=("test_dag", "test_run", "test_task", "test_key", 2), |
| response=ErrorResponse(error=ErrorType.XCOM_NOT_FOUND), |
| ), |
| test_id="get_xcom_seq_item_not_found", |
| ), |
| RequestTestCase( |
| message=GetXComSequenceSlice( |
| key="test_key", |
| dag_id="test_dag", |
| run_id="test_run", |
| task_id="test_task", |
| start=None, |
| stop=None, |
| step=None, |
| include_prior_dates=False, |
| ), |
| expected_body={"root": ["foo", "bar"], "type": "XComSequenceSliceResult"}, |
| client_mock=ClientMock( |
| method_path="xcoms.get_sequence_slice", |
| args=("test_dag", "test_run", "test_task", "test_key", None, None, None, False), |
| response=XComSequenceSliceResult(root=["foo", "bar"]), |
| ), |
| test_id="get_xcom_seq_slice", |
| ), |
| RequestTestCase( |
| message=TaskState(state=TaskInstanceState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")), |
| test_id="patch_task_instance_to_skipped", |
| ), |
| RequestTestCase( |
| message=CreateHITLDetailPayload( |
| ti_id=TI_ID, |
| options=["Approve", "Reject"], |
| subject="This is subject", |
| body="This is body", |
| defaults=["Approve"], |
| multiple=False, |
| params={}, |
| ), |
| expected_body={ |
| "ti_id": str(TI_ID), |
| "options": ["Approve", "Reject"], |
| "subject": "This is subject", |
| "body": "This is body", |
| "defaults": ["Approve"], |
| "params": {}, |
| "type": "HITLDetailRequestResult", |
| }, |
| client_mock=ClientMock( |
| method_path="hitl.add_response", |
| kwargs={ |
| "body": "This is body", |
| "defaults": ["Approve"], |
| "multiple": False, |
| "options": ["Approve", "Reject"], |
| "params": {}, |
| "assigned_users": None, |
| "subject": "This is subject", |
| "ti_id": TI_ID, |
| }, |
| response=HITLDetailRequestResult( |
| ti_id=TI_ID, |
| options=["Approve", "Reject"], |
| subject="This is subject", |
| body="This is body", |
| defaults=["Approve"], |
| multiple=False, |
| params={}, |
| ), |
| ), |
| test_id="create_hitl_detail_payload", |
| ), |
| RequestTestCase( |
| message=MaskSecret(value=["iter1", "iter2", {"key": "value"}], name="test_secret"), |
| mask_secret_args=(["iter1", "iter2", {"key": "value"}], "test_secret"), |
| test_id="mask_secret_list", |
| ), |
| RequestTestCase( |
| message=GetXComCount(key="test_key", dag_id="test_dag", run_id="test_run", task_id="test_task"), |
| expected_body={"len": 5, "type": "XComLengthResponse"}, |
| client_mock=ClientMock( |
| method_path="xcoms.head", |
| args=("test_dag", "test_run", "test_task", "test_key"), |
| response=XComCountResponse(len=5), |
| ), |
| test_id="get_xcom_count", |
| ), |
| RequestTestCase( |
| message=ResendLoggingFD(), |
| expected_body={"fds": mock.ANY, "type": "SentFDs"}, |
| test_id="resend_logging_fd", |
| ), |
| RequestTestCase( |
| message=SkipDownstreamTasks(tasks=["task1", "task2"]), |
| client_mock=ClientMock( |
| method_path="task_instances.skip_downstream_tasks", |
| args=(TI_ID, SkipDownstreamTasks(tasks=["task1", "task2"])), |
| response=OKResponse(ok=True), |
| ), |
| test_id="skip_downstream_tasks", |
| ), |
| RequestTestCase( |
| message=GetTaskBreadcrumbs(dag_id="test_dag", run_id="test_run"), |
| client_mock=ClientMock( |
| method_path="task_instances.get_task_breakcrumbs", |
| kwargs={"dag_id": "test_dag", "run_id": "test_run"}, |
| response=TaskBreadcrumbsResult( |
| breadcrumbs=[ |
| { |
| "task_id": "test_task", |
| "map_index": 2, |
| "state": "success", |
| "operator": "PythonOperator", |
| "duration": 432.0, |
| }, |
| ], |
| ), |
| ), |
| expected_body={ |
| "breadcrumbs": [ |
| { |
| "task_id": "test_task", |
| "map_index": 2, |
| "state": "success", |
| "operator": "PythonOperator", |
| "duration": 432.0, |
| }, |
| ], |
| "type": "TaskBreadcrumbsResult", |
| }, |
| test_id="get_task_breadcrumbs", |
| ), |
| ] |
| |
| |
| class TestHandleRequest: |
| @pytest.fixture |
| def watched_subprocess(self, mocker): |
| read_end, write_end = socket.socketpair() |
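        # The supervisor writes response frames to ``write_end`` (wired in as ``stdin``
        # below); the test reads them back from ``read_end`` to see what the task would see.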
| |
        proc = ActivitySubprocess(
            process_log=mocker.MagicMock(),
            id=TI_ID,
            pid=12345,
            stdin=write_end,
            client=mocker.Mock(),
            process=mocker.Mock(),
        )

        return proc, read_end
| |
| @patch("airflow.sdk.execution_time.supervisor.mask_secret") |
| @pytest.mark.parametrize("test_case", REQUEST_TEST_CASES, ids=lambda tc: tc.test_id) |
| def test_handle_requests( |
| self, |
| mock_mask_secret, |
| watched_subprocess, |
| mocker, |
| time_machine, |
| test_case: RequestTestCase, |
| ): |
| """ |
| Test handling of different messages to the subprocess. For any new message type, add a |
| new parameter set to the `@pytest.mark.parametrize` decorator. |
| |
| For each message type, this test: |
| |
| 1. Sends the message to the subprocess. |
| 2. Verifies that the correct client method is called with the expected argument. |
| 3. Checks that the buffer is updated with the expected response. |
| 4. Verifies that the response is correctly decoded. |
| """ |
| # Extract values from test_case |
| message = test_case.message |
| expected_body = test_case.expected_body |
| client_mock = test_case.client_mock |
| mask_secret_args = test_case.mask_secret_args |
| |
        # Unpack the fixture: the subprocess under test and the reader socket
        watched_subprocess, read_socket = watched_subprocess
| |
| # Mock the client method. E.g. `client.variables.get` or `client.connections.get` |
| if client_mock: |
| mock_client_method = attrgetter(client_mock.method_path)(watched_subprocess.client) |
| mock_client_method.return_value = client_mock.response |
| |
        # Drive the request-handling generator directly instead of going through the selector loop
        generator = watched_subprocess.handle_requests(log=mocker.Mock())
        # Prime the generator: advance it to its first `yield` so it can accept `.send()`
        next(generator)
| |
| req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=message.model_dump()) |
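        # The random 32-bit frame id is echoed back by the supervisor in its response
        # so the task-side decoder can match replies to requests (asserted below).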
| generator.send(req_frame) |
| |
| if mask_secret_args is not None: |
| mock_mask_secret.assert_called_with(*mask_secret_args) |
| |
| time_machine.move_to(timezone.datetime(2024, 10, 31), tick=False) |
| |
| # Verify the correct client method was called |
| if client_mock: |
| mock_client_method.assert_called_once_with(*client_mock.args, **client_mock.kwargs) |
| |
| # Read response from the read end of the socket |
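        # Frames use length-prefixed framing: a 4-byte big-endian length header
        # followed by that many bytes of a msgpack-encoded ``_ResponseFrame``.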
| read_socket.settimeout(0.1) |
| frame_len = int.from_bytes(read_socket.recv(4), "big") |
| bytes = read_socket.recv(frame_len) |
| frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) |
| |
| assert frame.id == req_frame.id |
| |
        # Verify the response frame carries the expected body
        assert frame.body == expected_body
| |
| # Verify the response is correctly decoded |
| # This is important because the subprocess/task runner will read the response |
| # and deserialize it to the correct message type |
| |
| if frame.body is not None and client_mock: |
| decoder = CommsDecoder(socket=None).body_decoder # type: ignore[var-annotated, arg-type] |
| assert decoder.validate_python(frame.body) == client_mock.response |
| |
| def test_all_to_supervisor_messages_are_covered(self): |
| """Ensure all ToSupervisor message types have test coverage.""" |
| |
        # ToSupervisor is an Annotated union: __args__[0] is the underlying Union,
        # and its __args__ are the individual message types
        union_type = ToSupervisor.__args__[0]
        supervisor_message_types = set(union_type.__args__)
| |
| # Get all message types covered in our test cases |
| tested_message_types = {type(test_case.message) for test_case in REQUEST_TEST_CASES} |
| |
| # Message types which are excluded for a good reason |
| excluded_message_types = { |
| GetHITLDetailResponse, # Only used in Triggerer, not needed in worker |
| UpdateHITLDetail, # Only used in Triggerer, not needed in worker |
| } |
| |
| untested_types = supervisor_message_types - tested_message_types - excluded_message_types |
| |
| # Assert all types are covered |
| assert not untested_types, ( |
| f"Missing test coverage for {len(untested_types)}/{len(supervisor_message_types)} " |
| f"ToSupervisor message types:\n" |
| + "\n".join(f" - {t.__name__}" for t in sorted(untested_types, key=lambda x: x.__name__)) |
| + "\n\nPlease add test cases to REQUEST_TEST_CASES." |
| ) |
| |
| def test_handle_requests_api_server_error(self, watched_subprocess, mocker): |
| """Test that API server errors are properly handled and sent back to the task.""" |
| |
| # Unpack subprocess and the reader socket |
| watched_subprocess, read_socket = watched_subprocess |
| |
| error = ServerResponseError( |
| message="API Server Error", |
| request=httpx.Request("GET", "http://test"), |
| response=httpx.Response(500, json={"detail": "Internal Server Error"}), |
| ) |
| |
| mock_client_method = mocker.Mock(side_effect=error) |
| watched_subprocess.client.task_instances.succeed = mock_client_method |
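        # The supervisor should catch the ServerResponseError and translate it into an
        # ErrorResponse frame for the task instead of crashing (asserted below).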
| |
| # Initialize and send message |
| generator = watched_subprocess.handle_requests(log=mocker.Mock()) |
| |
| next(generator) |
| |
| msg = SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")) |
| req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=msg.model_dump()) |
| generator.send(req_frame) |
| |
| # Read response from the read end of the socket |
| read_socket.settimeout(0.1) |
| frame_len = int.from_bytes(read_socket.recv(4), "big") |
| bytes = read_socket.recv(frame_len) |
| frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) |
| |
| assert frame.id == req_frame.id |
| |
| assert frame.error == { |
| "error": "API_SERVER_ERROR", |
| "detail": { |
| "status_code": 500, |
| "message": "API Server Error", |
| "detail": {"detail": "Internal Server Error"}, |
| }, |
| "type": "ErrorResponse", |
| } |
| |
| # Verify the error can be decoded correctly |
| comms = CommsDecoder(socket=None) |
| with pytest.raises(AirflowRuntimeError) as exc_info: |
| comms._from_frame(frame) |
| |
| assert exc_info.value.error.error == ErrorType.API_SERVER_ERROR |
| assert exc_info.value.error.detail == { |
| "status_code": error.response.status_code, |
| "message": str(error), |
| "detail": error.response.json(), |
| } |
| |
| |
| class TestSetSupervisorComms: |
| class DummyComms: |
| pass |
| |
| @pytest.fixture(autouse=True) |
| def cleanup_supervisor_comms(self): |
| # Ensure clean state before/after test |
| if hasattr(task_runner, "SUPERVISOR_COMMS"): |
| delattr(task_runner, "SUPERVISOR_COMMS") |
| yield |
| if hasattr(task_runner, "SUPERVISOR_COMMS"): |
| delattr(task_runner, "SUPERVISOR_COMMS") |
| |
| def test_set_supervisor_comms_overrides_and_restores(self): |
| task_runner.SUPERVISOR_COMMS = self.DummyComms() |
| original = task_runner.SUPERVISOR_COMMS |
| replacement = self.DummyComms() |
| |
| with set_supervisor_comms(replacement): |
| assert task_runner.SUPERVISOR_COMMS is replacement |
| assert task_runner.SUPERVISOR_COMMS is original |
| |
| def test_set_supervisor_comms_sets_temporarily_when_not_set(self): |
| assert not hasattr(task_runner, "SUPERVISOR_COMMS") |
| replacement = self.DummyComms() |
| |
| with set_supervisor_comms(replacement): |
| assert task_runner.SUPERVISOR_COMMS is replacement |
| assert not hasattr(task_runner, "SUPERVISOR_COMMS") |
| |
| def test_set_supervisor_comms_unsets_temporarily_when_not_set(self): |
| assert not hasattr(task_runner, "SUPERVISOR_COMMS") |
| |
        # Passing None for an attribute that isn't set: it should stay unset both
        # inside the context manager and after it exits
        with set_supervisor_comms(None):
| assert not hasattr(task_runner, "SUPERVISOR_COMMS") |
| |
| assert not hasattr(task_runner, "SUPERVISOR_COMMS") |
| |
| |
| class TestInProcessTestSupervisor: |
| def test_inprocess_supervisor_comms_roundtrip(self): |
| """ |
| Test that InProcessSupervisorComms correctly sends a message to the supervisor, |
| and that the supervisor's response is received via the message queue. |
| |
| This verifies the end-to-end communication flow: |
| - send_request() dispatches a message to the supervisor |
| - the supervisor handles the request and appends a response via send_msg() |
| - get_message() returns the enqueued response |
| |
| This test mocks the supervisor's `_handle_request()` method to simulate |
| a simple echo-style response, avoiding full task execution. |
| """ |
| |
| class MinimalSupervisor(InProcessTestSupervisor): |
| def _handle_request(self, msg, log, req_id): |
| resp = VariableResult(key=msg.key, value="value") |
| self.send_msg(resp, req_id) |
| |
| supervisor = MinimalSupervisor( |
| id="test", |
| pid=123, |
| process=MagicMock(), |
| process_log=MagicMock(), |
| client=MagicMock(), |
| ) |
| comms = InProcessSupervisorComms(supervisor=supervisor) |
| supervisor.comms = comms |
| |
| test_msg = GetVariable(key="test_key") |
| |
| response = comms.send(test_msg) |
| |
| # Ensure we got back what we expect |
| assert isinstance(response, VariableResult) |
| assert response.value == "value" |
| |
| |
| class TestInProcessClient: |
| def test_no_retries(self): |
| called = 0 |
| |
| def noop_handler(request: httpx.Request) -> httpx.Response: |
| nonlocal called |
| called += 1 |
| return httpx.Response(500) |
| |
| transport = httpx.MockTransport(noop_handler) |
| client = InProcessTestSupervisor._Client( |
| base_url="http://local.invalid", token="", transport=transport |
| ) |
| |
| with pytest.raises(httpx.HTTPStatusError): |
| client.get("/goo") |
| |
| assert called == 1 |
| |
| |
| @pytest.mark.parametrize( |
| ("remote_logging", "remote_conn", "expected_env"), |
| ( |
| pytest.param(True, "", "AIRFLOW_CONN_AWS_DEFAULT", id="no-conn-id"), |
| pytest.param(True, "aws_default", "AIRFLOW_CONN_AWS_DEFAULT", id="explicit-default"), |
| pytest.param(True, "my_aws", "AIRFLOW_CONN_MY_AWS", id="other"), |
| pytest.param(False, "", "", id="no-remote-logging"), |
| ), |
| ) |
| def test_remote_logging_conn(remote_logging, remote_conn, expected_env, monkeypatch, mocker): |
| # This doesn't strictly need the AWS provider, but it does need something that |
| # airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG knows about |
| pytest.importorskip("airflow.providers.amazon", reason="'amazon' provider not installed") |
| |
| # This test is a little bit overly specific to how the logging is currently configured :/ |
| monkeypatch.delitem(sys.modules, "airflow.logging_config") |
| monkeypatch.delitem(sys.modules, "airflow.config_templates.airflow_local_settings", raising=False) |
| monkeypatch.delitem(sys.modules, "airflow.sdk.log", raising=False) |
| |
| def handle_request(request: httpx.Request) -> httpx.Response: |
| return httpx.Response( |
| status_code=200, |
| json={ |
| # Minimal enough to pass validation, we don't care what fields are in here for the tests |
| "conn_id": remote_conn, |
| "conn_type": "aws", |
| }, |
| ) |
| |
| # Patch configurations in both airflow-core and task-sdk due to shared library refactoring. |
| # |
| # conf_vars() patches airflow.configuration.conf (airflow-core): |
| # - remote_logging: needed by airflow_local_settings.py to decide whether to set up REMOTE_TASK_LOG |
| # - remote_base_log_folder: needed by airflow_local_settings.py to create the CloudWatch handler |
| # |
| # task_sdk_conf_vars() patches airflow.sdk.configuration.conf (task-sdk): |
| # - remote_log_conn_id: needed by load_remote_conn_id() to return the correct connection id |
| with conf_vars( |
| { |
| ("logging", "remote_logging"): str(remote_logging), |
| ("logging", "remote_base_log_folder"): "cloudwatch://arn:aws:logs:::log-group:test", |
| ("logging", "remote_log_conn_id"): remote_conn, |
| } |
| ): |
        with task_sdk_conf_vars(
            {
                ("logging", "remote_log_conn_id"): remote_conn,
            }
        ):
| env = os.environ.copy() |
| client = make_client(transport=httpx.MockTransport(handle_request)) |
| |
| with _remote_logging_conn(client): |
| new_keys = os.environ.keys() - env.keys() |
| if remote_logging: |
| # _remote_logging_conn sets both the connection env var and _AIRFLOW_PROCESS_CONTEXT |
| assert new_keys == {expected_env, "_AIRFLOW_PROCESS_CONTEXT"} |
| else: |
| assert not new_keys |
| |
| if remote_logging and expected_env: |
| connection_available = {"available": False, "conn_uri": None} |
| |
| def mock_upload_to_remote(process_log, ti): |
| connection_available["available"] = expected_env in os.environ |
| connection_available["conn_uri"] = os.environ.get(expected_env) |
| |
| mocker.patch("airflow.sdk.log.upload_to_remote", side_effect=mock_upload_to_remote) |
| |
| activity_subprocess = ActivitySubprocess( |
| process_log=mocker.MagicMock(), |
| id=TI_ID, |
| pid=12345, |
| stdin=mocker.MagicMock(), |
| client=client, |
| process=mocker.MagicMock(), |
| ) |
| activity_subprocess.ti = mocker.MagicMock() |
| |
| activity_subprocess._upload_logs() |
| |
| assert connection_available["available"], ( |
| f"Connection {expected_env} was not available during upload_to_remote call" |
| ) |
| assert connection_available["conn_uri"] is not None, "Connection URI was None during upload" |
| |
| |
| def test_remote_logging_conn_sets_process_context(monkeypatch, mocker): |
| """ |
| Test that _remote_logging_conn sets _AIRFLOW_PROCESS_CONTEXT=client. |
| """ |
| pytest.importorskip("airflow.providers.amazon", reason="'amazon' provider not installed") |
| from airflow.models.connection import Connection as CoreConnection |
| from airflow.sdk.definitions.connection import Connection as SDKConnection |
| |
| monkeypatch.delitem(sys.modules, "airflow.logging_config") |
| monkeypatch.delitem(sys.modules, "airflow.config_templates.airflow_local_settings", raising=False) |
| monkeypatch.delitem(sys.modules, "airflow.sdk.log", raising=False) |
| |
| conn_id = "s3_conn_logs" |
| conn_uri = "aws:///?region_name=us-east-1" |
| |
| def handle_request(request: httpx.Request) -> httpx.Response: |
| return httpx.Response( |
| status_code=200, |
| json={ |
| "conn_id": conn_id, |
| "conn_type": "aws", |
| "host": None, |
| "login": None, |
| "password": None, |
| "port": None, |
| "schema": None, |
| "extra": '{"region_name": "us-east-1"}', |
| }, |
| ) |
| |
| with conf_vars( |
| { |
| ("logging", "remote_logging"): "True", |
| ("logging", "remote_base_log_folder"): "s3://bucket/logs", |
| ("logging", "remote_log_conn_id"): conn_id, |
| } |
| ): |
        with task_sdk_conf_vars(
            {
                ("logging", "remote_log_conn_id"): conn_id,
            }
        ):
| client = make_client(transport=httpx.MockTransport(handle_request)) |
| |
| assert os.getenv("_AIRFLOW_PROCESS_CONTEXT") is None |
| |
| conn_env_key = f"AIRFLOW_CONN_{conn_id.upper()}" |
| |
| with _remote_logging_conn(client): |
| assert os.getenv("_AIRFLOW_PROCESS_CONTEXT") == "client" |
| |
| assert conn_env_key in os.environ |
| stored_uri = os.environ[conn_env_key] |
| assert stored_uri == conn_uri |
| |
| # Verify that Connection.get() uses SDK Connection class when _AIRFLOW_PROCESS_CONTEXT=client |
| # Without _AIRFLOW_PROCESS_CONTEXT=client, _get_connection_class() would return core |
| # Connection. While core Connection can handle URI deserialization via its __init__, |
| # using SDK Connection ensures consistency and proper behavior in supervisor context. |
| from airflow.sdk.execution_time.context import _get_connection |
| |
| retrieved_conn = _get_connection(conn_id) |
| |
| assert isinstance(retrieved_conn, SDKConnection) |
| assert not isinstance(retrieved_conn, CoreConnection) |
| assert retrieved_conn.conn_id == conn_id |
| assert retrieved_conn.conn_type == "aws" |
| |
| # Verify _AIRFLOW_PROCESS_CONTEXT and env var is cleaned up |
| assert os.getenv("_AIRFLOW_PROCESS_CONTEXT") is None |
| assert conn_env_key not in os.environ |
| |
| |
| class TestSignalRetryLogic: |
| """Test retry logic for exit codes (signals and non-signal failures) in ActivitySubprocess.""" |
| |
    @pytest.mark.parametrize(
        "signum",
        [
            signal.SIGTERM,
            signal.SIGKILL,
            signal.SIGABRT,
            signal.SIGSEGV,
        ],
    )
    def test_signals_with_retry(self, mocker, signum):
        """Test that fatal signals with task retries enabled go to UP_FOR_RETRY."""
        mock_watched_subprocess = ActivitySubprocess(
            process_log=mocker.MagicMock(),
            id=TI_ID,
            pid=12345,
            stdin=mocker.Mock(),
            process=mocker.Mock(),
            client=mocker.Mock(),
        )

        mock_watched_subprocess._exit_code = -signum
        mock_watched_subprocess._should_retry = True

        result = mock_watched_subprocess.final_state
        assert result == TaskInstanceState.UP_FOR_RETRY
| |
    @pytest.mark.parametrize(
        "signum",
        [
            signal.SIGKILL,
            signal.SIGTERM,
            signal.SIGABRT,
            signal.SIGSEGV,
        ],
    )
    def test_signals_without_retry_always_fail(self, mocker, signum):
        """Test that signals without task retries enabled always fail."""
        mock_watched_subprocess = ActivitySubprocess(
            process_log=mocker.MagicMock(),
            id=TI_ID,
            pid=12345,
            stdin=mocker.Mock(),
            process=mocker.Mock(),
            client=mocker.Mock(),
        )
        mock_watched_subprocess._should_retry = False
        mock_watched_subprocess._exit_code = -signum

        result = mock_watched_subprocess.final_state
        assert result == TaskInstanceState.FAILED
| |
| def test_non_signal_exit_code_with_retry_goes_to_up_for_retry(self, mocker): |
| """Test that non-signal exit codes with retries enabled go to UP_FOR_RETRY.""" |
| mock_watched_subprocess = ActivitySubprocess( |
| process_log=mocker.MagicMock(), |
| id=TI_ID, |
| pid=12345, |
| stdin=mocker.Mock(), |
| process=mocker.Mock(), |
| client=mocker.Mock(), |
| ) |
| mock_watched_subprocess._exit_code = 1 |
| mock_watched_subprocess._should_retry = True |
| |
| assert mock_watched_subprocess.final_state == TaskInstanceState.UP_FOR_RETRY |
| |
| def test_non_signal_exit_code_without_retry_goes_to_failed(self, mocker): |
| """Test that non-signal exit codes without retries enabled go to FAILED.""" |
| mock_watched_subprocess = ActivitySubprocess( |
| process_log=mocker.MagicMock(), |
| id=TI_ID, |
| pid=12345, |
| stdin=mocker.Mock(), |
| process=mocker.Mock(), |
| client=mocker.Mock(), |
| ) |
| mock_watched_subprocess._exit_code = 1 |
| mock_watched_subprocess._should_retry = False |
| |
| assert mock_watched_subprocess.final_state == TaskInstanceState.FAILED |
| |
| |
| def test_remote_logging_conn_caches_connection_not_client(monkeypatch): |
| """Test that connection caching doesn't retain API client references.""" |
| import gc |
| import weakref |
| |
| monkeypatch.delitem(sys.modules, "airflow.logging_config") |
| monkeypatch.delitem(sys.modules, "airflow.config_templates.airflow_local_settings", raising=False) |
| monkeypatch.delitem(sys.modules, "airflow.sdk.log", raising=False) |
| |
| from airflow.sdk.execution_time import supervisor |
| |
| class ExampleBackend: |
| def __init__(self): |
| self.calls = 0 |
| |
| def get_connection(self, conn_id: str): |
| self.calls += 1 |
| from airflow.sdk.definitions.connection import Connection |
| |
| return Connection(conn_id=conn_id, conn_type="example") |
| |
| backend = ExampleBackend() |
| monkeypatch.setattr(supervisor, "ensure_secrets_backend_loaded", lambda: [backend]) |
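    # Remove any env-var connection so the lookup has to go through the backend above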
| monkeypatch.delenv("AIRFLOW_CONN_TEST_CONN", raising=False) |
| |
| with conf_vars( |
| { |
| ("logging", "remote_logging"): "True", |
| ("logging", "remote_base_log_folder"): "s3://bucket/logs", |
| ("logging", "remote_log_conn_id"): "test_conn", |
| } |
| ): |
| |
| def noop_request(request: httpx.Request) -> httpx.Response: |
| return httpx.Response(200) |
| |
| clients = [] |
| for _ in range(3): |
| client = make_client(transport=httpx.MockTransport(noop_request)) |
| clients.append(weakref.ref(client)) |
| with _remote_logging_conn(client): |
| pass |
| client.close() |
| del client |
| |
| gc.collect() |
| assert backend.calls == 1, "Connection should be cached, not fetched multiple times" |
| assert all(ref() is None for ref in clients), "Client instances should be garbage collected" |
| |
| |
| def test_process_log_messages_from_subprocess(monkeypatch, caplog): |
| from airflow.sdk._shared.logging.structlog import PER_LOGGER_LEVELS |
| |
| read_end, write_end = socket.socketpair() |
| |
| # Set global level at warning |
| monkeypatch.setitem(PER_LOGGER_LEVELS, "", logging.WARNING) |
| output_log = structlog.get_logger() |
| |
| gen = process_log_messages_from_subprocess(loggers=(output_log,)) |
| |
| # We need to start up the generator to get it to the point it's at waiting on the yield |
| next(gen) |
| |
| # Now we can send in messages to it. |
| gen.send(b'{"level": "debug", "event": "A debug"}\n') |
| gen.send(b'{"level": "error", "event": "An error"}\n') |
| |
| assert caplog.record_tuples == [ |
| (None, logging.DEBUG, "A debug"), |
| (None, logging.ERROR, "An error"), |
| ] |
| |
| |
| def test_reinit_supervisor_comms(monkeypatch, client_with_ti_start, caplog): |
| def subprocess_main(): |
| # This is run in the subprocess! |
| |
| # Ensure we follow the "protocol" and get the startup message before we do anything else |
| c = CommsDecoder() |
| c._get_response() |
| |
        # This mirrors what the VirtualEnvProvider puts in its script
| script = """ |
| import os |
| import sys |
| import structlog |
| |
| from airflow.sdk import Connection |
| from airflow.sdk.execution_time.task_runner import reinit_supervisor_comms |
| |
| reinit_supervisor_comms() |
| |
| Connection.get("a") |
| print("ok") |
| sys.stdout.flush() |
| |
| structlog.get_logger().info("is connected") |
| """ |
| # Now we launch a new process, as VirtualEnvOperator will do |
| subprocess.check_call([sys.executable, "-c", dedent(script)]) |
| |
| client_with_ti_start.connections.get.return_value = ConnectionResult( |
| conn_id="test_conn", conn_type="mysql", login="a", password="password1" |
| ) |
| proc = ActivitySubprocess.start( |
| dag_rel_path=os.devnull, |
| bundle_info=FAKE_BUNDLE, |
| what=TaskInstance( |
| id="4d828a62-a417-4936-a7a6-2b3fabacecab", |
| task_id="b", |
| dag_id="c", |
| run_id="d", |
| try_number=1, |
| dag_version_id=uuid7(), |
| ), |
| client=client_with_ti_start, |
| target=subprocess_main, |
| ) |
| |
| rc = proc.wait() |
| |
| assert rc == 0, caplog.text |
    # Check that the log messages are right. We should expect stdout to appear as-is, and crucially, we
    # should expect logs from the venv process to appear without extra "wrapping"
| assert { |
| "logger": "task.stdout", |
| "event": "ok", |
| "log_level": "info", |
| "timestamp": mock.ANY, |
| } in caplog, caplog.text |
| assert { |
| "logger_name": "task", |
| "log_level": "info", |
| "event": "is connected", |
| "timestamp": mock.ANY, |
| } in caplog, caplog.text |
| |
| |
| _NOBODY_UID = 65534 |
| |
| |
| def _drop_root_if_needed(): |
| """Drop to a non-root UID so kernel dumpable checks actually apply (root/CAP_SYS_PTRACE bypasses them).""" |
| if os.getuid() == 0: |
| os.setuid(_NOBODY_UID) |
| |
| |
| @pytest.mark.skipif(sys.platform != "linux", reason="PR_SET_DUMPABLE is Linux-only") |
| def test_nondumpable_blocks_sibling_proc_read(): |
| """A sibling process (same non-root UID) cannot read /proc/<pid>/environ or /proc/<pid>/mem of a nondumpable process.""" |
| import multiprocessing |
| |
| ready = multiprocessing.Event() |
| done = multiprocessing.Event() |
| result_queue = multiprocessing.Queue() |
| |
| def target_fn(): |
| _drop_root_if_needed() |
| _make_process_nondumpable() |
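        # With dumpable cleared, the kernel re-owns /proc/<pid>/* to root, so
        # same-UID readers should get PermissionError on environ/mem.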
| ready.set() |
| done.wait(timeout=10) |
| |
| def reader_fn(target_pid): |
| _drop_root_if_needed() |
| blocked = [] |
| for proc_file in ("environ", "mem"): |
| try: |
| open(f"/proc/{target_pid}/{proc_file}").read() |
| except PermissionError: |
| blocked.append(proc_file) |
| result_queue.put(blocked) |
| |
| target = multiprocessing.Process(target=target_fn) |
| target.start() |
| try: |
| assert ready.wait(timeout=5), "target process did not become ready" |
| reader = multiprocessing.Process(target=reader_fn, args=(target.pid,)) |
| reader.start() |
| reader.join(timeout=5) |
| blocked = result_queue.get(timeout=5) |
| assert "environ" in blocked, "Sibling was able to read nondumpable process's /proc/environ" |
| assert "mem" in blocked, "Sibling was able to read nondumpable process's /proc/mem" |
| finally: |
| done.set() |
| target.join(timeout=5) |
| if target.is_alive(): |
| target.kill() |
| |
| |
| @pytest.mark.skipif(sys.platform != "linux", reason="PR_SET_DUMPABLE is Linux-only") |
| def test_nondumpable_blocks_child_memory_read(): |
| """A forked child (same non-root UID) cannot read its nondumpable parent's /proc/<pid>/mem.""" |
| import multiprocessing |
| |
| result_queue = multiprocessing.Queue() |
| |
| def parent_fn(): |
| _drop_root_if_needed() |
| _make_process_nondumpable() |
| parent_pid = os.getpid() |
| child_pid = os.fork() |
| if child_pid == 0: |
| try: |
| open(f"/proc/{parent_pid}/mem").read() |
| except PermissionError: |
| os._exit(0) |
| else: |
| os._exit(1) |
| _, status = os.waitpid(child_pid, 0) |
| result_queue.put(os.WEXITSTATUS(status) if os.WIFEXITED(status) else -1) |
| |
| proc = multiprocessing.Process(target=parent_fn) |
| proc.start() |
| proc.join(timeout=10) |
| exit_code = result_queue.get(timeout=5) |
| assert exit_code == 0, "Child was able to read parent's /proc/mem — expected PermissionError" |
| |
| |
| @pytest.mark.skipif(sys.platform == "linux", reason="Test is for non-Linux platforms only") |
| def test_nondumpable_noop_on_non_linux(): |
| """On non-Linux, _make_process_nondumpable returns without error.""" |
| |
| _make_process_nondumpable() |