| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| import os |
| from unittest import mock |
| |
| import pytest |
| from sqlalchemy.orm import make_transient |
| |
| from airflow.configuration import TEST_DAGS_FOLDER |
| from airflow.models.renderedtifields import RenderedTaskInstanceFields, RenderedTaskInstanceFields as RTIF |
| from airflow.operators.bash import BashOperator |
| from airflow.utils.session import create_session |
| from airflow.version import version |
| from tests.models import DEFAULT_DATE |
| |
| pytestmark = pytest.mark.db_test |
| |
| |
@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
@mock.patch("airflow.settings.pod_mutation_hook")
def test_render_k8s_pod_yaml(pod_mutation_hook, create_task_instance):
    """
    Compare the pod spec rendered for a task instance with the expected dict.

    Also checks that the patched ``airflow.settings.pod_mutation_hook`` is
    called exactly once.
    """
    dag_id = "test_render_k8s_pod_yaml"
    run_id = "test_run_id"
    task_id = "op1"

    ti = create_task_instance(
        dag_id=dag_id,
        run_id=run_id,
        task_id=task_id,
        execution_date=DEFAULT_DATE,
    )

    annotations = {
        "dag_id": dag_id,
        "run_id": run_id,
        "task_id": task_id,
        "try_number": "0",
    }
    labels = {
        "airflow-worker": "0",
        "airflow_version": version,
        "dag_id": dag_id,
        "run_id": run_id,
        "kubernetes_executor": "True",
        "task_id": task_id,
        "try_number": "0",
    }
    base_container = {
        "args": [
            "airflow",
            "tasks",
            "run",
            dag_id,
            task_id,
            run_id,
            "--subdir",
            __file__,
        ],
        "name": "base",
        "env": [{"name": "AIRFLOW_IS_K8S_EXECUTOR_POD", "value": "True"}],
    }
    expected_pod_spec = {
        "metadata": {
            "annotations": annotations,
            "labels": labels,
            # The pod name is not pinned by this test, so accept any value.
            "name": mock.ANY,
            "namespace": "default",
        },
        "spec": {"containers": [base_container]},
    }

    rendered_spec = ti.render_k8s_pod_yaml()

    assert rendered_spec == expected_pod_spec
    pod_mutation_hook.assert_called_once_with(mock.ANY)
| |
| |
@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
@mock.patch.object(RenderedTaskInstanceFields, "get_k8s_pod_yaml")
@mock.patch("airflow.providers.cncf.kubernetes.template_rendering.render_k8s_pod_yaml")
def test_get_rendered_k8s_spec(render_k8s_pod_yaml, rtif_get_k8s_pod_yaml, create_task_instance):
    """
    Check where ``ti.get_rendered_k8s_spec`` gets the pod spec from.

    When ``RenderedTaskInstanceFields.get_k8s_pod_yaml`` returns a spec, that
    spec is returned and ``render_k8s_pod_yaml`` is not called. When it returns
    ``None``, ``render_k8s_pod_yaml`` is called and its result is returned.
    """
    ti = create_task_instance()

    # Both collaborators are already replaced by the decorators above, so no
    # further patching of the task instance is required. (A
    # ``mock.patch.object(...)`` patcher that is never started has no effect.)
    fake_spec = {"ermagawds": "pods"}
    session = mock.Mock()

    # Case 1: the spec is returned by the RTIF lookup.
    rtif_get_k8s_pod_yaml.return_value = fake_spec
    assert ti.get_rendered_k8s_spec(session) == fake_spec

    rtif_get_k8s_pod_yaml.assert_called_once_with(ti, session=session)
    render_k8s_pod_yaml.assert_not_called()

    # Case 2: when we _don't_ find it in the DB, it calls render_k8s_pod_yaml.
    rtif_get_k8s_pod_yaml.return_value = None
    render_k8s_pod_yaml.return_value = fake_spec

    assert ti.get_rendered_k8s_spec(session) == fake_spec

    render_k8s_pod_yaml.assert_called_once()
| |
| |
@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
@mock.patch("airflow.utils.log.secrets_masker.redact", autospec=True, side_effect=lambda d, _=None: d)
@mock.patch("airflow.providers.cncf.kubernetes.template_rendering.render_k8s_pod_yaml")
def test_get_k8s_pod_yaml(render_k8s_pod_yaml, redact, dag_maker):
    """
    Test that k8s_pod_yaml is rendered, passed to ``redact``, stored in the
    database, and fetched back with RTIF.get_k8s_pod_yaml.

    After the session is rolled back, the same lookup must return ``None``.
    """
    expected_pod_yaml = {"I'm a": "pod"}

    with dag_maker("test_get_k8s_pod_yaml") as dag:
        task = BashOperator(task_id="test", bash_command="echo hi")
    dag_run = dag_maker.create_dagrun()
    dag.fileloc = TEST_DAGS_FOLDER + "/test_get_k8s_pod_yaml.py"

    ti = dag_run.task_instances[0]
    ti.task = task

    # The mocked renderer supplies the pod spec that RTIF captures on creation.
    render_k8s_pod_yaml.return_value = {"I'm a": "pod"}
    rendered_fields = RTIF(ti=ti)

    # The RTIF row must be keyed to the same task instance.
    for key_attr in ("dag_id", "task_id", "run_id"):
        assert getattr(ti, key_attr) == getattr(rendered_fields, key_attr)

    assert rendered_fields.k8s_pod_yaml == render_k8s_pod_yaml.return_value
    # K8s pod spec dict was passed to redact
    redact.assert_any_call(rendered_fields.k8s_pod_yaml)

    with create_session() as session:
        session.add(rendered_fields)
        session.flush()

        fetched = RTIF.get_k8s_pod_yaml(ti=ti, session=session)
        assert expected_pod_yaml == fetched

        make_transient(ti)
        # "Delete" it from the DB
        session.rollback()

        # Exercise the fallback branch of get_k8s_pod_yaml: a TI with no row
        # in the RTIF table yields None.
        missing = RTIF.get_k8s_pod_yaml(ti=ti, session=session)
        assert missing is None