| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| import json |
| import logging |
| import os |
| import shutil |
| import sys |
| from copy import copy |
| from tempfile import NamedTemporaryFile |
| from unittest import mock |
| from unittest.mock import ANY, MagicMock |
| from uuid import uuid4 |
| |
| import pendulum |
| import pytest |
| from kubernetes.client import models as k8s |
| from kubernetes.client.api_client import ApiClient |
| from kubernetes.client.rest import ApiException |
| from pytest import param |
| |
| from airflow.exceptions import AirflowException |
| from airflow.models import DAG, Connection, DagRun, TaskInstance |
| from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook |
| from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator |
| from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager |
| from airflow.utils import timezone |
| from airflow.utils.context import Context |
| from airflow.utils.types import DagRunType |
| from airflow.version import version as airflow_version |
| |
| HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesHook" |
| POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager" |
| |
| |
def create_context(task) -> Context:
    """Build a minimal task ``Context`` for running ``task`` directly in a test.

    A DAG with ``dag_id="dag"`` and a manual ``DagRun`` with a fixed execution date
    (2016-01-01 01:00, Europe/Amsterdam) are created, and a ``TaskInstance`` for
    ``task`` is attached to that run. ``xcom_push`` on the task instance is replaced
    by a mock.

    :param task: the operator instance the context is built for
    :return: a ``Context`` holding the dag, run id, task and task instance
    """
    dag = DAG(dag_id="dag")
    execution_date = timezone.datetime(
        2016, 1, 1, 1, 0, 0, tzinfo=pendulum.tz.timezone("Europe/Amsterdam")
    )
    run_id = DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
    dag_run = DagRun(dag_id=dag.dag_id, execution_date=execution_date, run_id=run_id)

    ti = TaskInstance(task=task)
    ti.dag_run = dag_run
    ti.dag_id = dag.dag_id
    ti.xcom_push = mock.Mock()  # type: ignore

    return Context(dag=dag, run_id=dag_run.run_id, task=task, ti=ti, task_instance=ti)
| |
| |
@pytest.fixture(scope="session")
def kubeconfig_path():
    """Return the kubeconfig path: ``$KUBECONFIG`` when set and non-empty, else ``~/.kube/config``."""
    return os.environ.get("KUBECONFIG") or os.path.expanduser("~/.kube/config")
| |
| |
@pytest.fixture
def test_label(request):
    """Derive a label value from the requesting test's class name and test name.

    Non-alphanumeric characters are removed, the result is lower-cased, and only the
    last 63 characters are kept (presumably the Kubernetes label-value length
    limit -- confirm).
    """
    qualified_name = f"{request.node.cls.__name__}.{request.node.name}"
    alnum_only = "".join(ch for ch in qualified_name if ch.isalnum())
    return alnum_only.lower()[-63:]
| |
| |
@pytest.fixture()
def mock_get_connection():
    """Patch ``get_connection`` on the hook class to return a bare ``kubernetes_default`` connection."""
    default_conn = Connection(conn_id="kubernetes_default")
    patcher = mock.patch(f"{HOOK_CLASS}.get_connection", return_value=default_conn)
    with patcher:
        yield
| |
| |
| @pytest.mark.execution_timeout(180) |
| @pytest.mark.usefixtures("mock_get_connection") |
| class TestKubernetesPodOperatorSystem: |
| @pytest.fixture(autouse=True) |
| def setup_tests(self, test_label): |
| self.api_client = ApiClient() |
| self.labels = {"test_label": test_label} |
| self.expected_pod = { |
| "apiVersion": "v1", |
| "kind": "Pod", |
| "metadata": { |
| "namespace": "default", |
| "name": ANY, |
| "annotations": {}, |
| "labels": { |
| "test_label": test_label, |
| "kubernetes_pod_operator": "True", |
| "airflow_version": airflow_version.replace("+", "-"), |
| "airflow_kpo_in_cluster": "False", |
| "run_id": "manual__2016-01-01T0100000100-da4d1ce7b", |
| "dag_id": "dag", |
| "task_id": ANY, |
| "try_number": "1", |
| }, |
| }, |
| "spec": { |
| "affinity": {}, |
| "containers": [ |
| { |
| "image": "ubuntu:16.04", |
| "args": ["echo 10"], |
| "command": ["bash", "-cx"], |
| "env": [], |
| "envFrom": [], |
| "name": "base", |
| "ports": [], |
| "volumeMounts": [], |
| } |
| ], |
| "hostNetwork": False, |
| "imagePullSecrets": [], |
| "initContainers": [], |
| "nodeSelector": {}, |
| "restartPolicy": "Never", |
| "securityContext": {}, |
| "tolerations": [], |
| "volumes": [], |
| }, |
| } |
| yield |
| hook = KubernetesHook(conn_id=None, in_cluster=False) |
| client = hook.core_v1_client |
| client.delete_collection_namespaced_pod(namespace="default", grace_period_seconds=0) |
| |
| def _get_labels_selector(self) -> str | None: |
| if not self.labels: |
| return None |
| return ",".join([f"{key}={value}" for key, value in enumerate(self.labels)]) |
| |
| def test_do_xcom_push_defaults_false(self, kubeconfig_path, mock_get_connection): |
| with NamedTemporaryFile(prefix="kube_config", suffix=".cfg") as f: |
| new_config_path = f.name |
| shutil.copy(kubeconfig_path, new_config_path) |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| config_file=new_config_path, |
| ) |
| assert not k.do_xcom_push |
| |
| def test_config_path_move(self, kubeconfig_path, mock_get_connection): |
| with NamedTemporaryFile(prefix="kube_config", suffix=".cfg") as f: |
| new_config_path = f.name |
| shutil.copy(kubeconfig_path, new_config_path) |
| |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| is_delete_operator_pod=False, |
| config_file=new_config_path, |
| ) |
| context = create_context(k) |
| k.execute(context) |
| expected_pod = copy(self.expected_pod) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| assert actual_pod == expected_pod |
| |
| def test_working_pod(self, mock_get_connection): |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| ) |
| context = create_context(k) |
| k.execute(context) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| assert self.expected_pod["spec"] == actual_pod["spec"] |
| assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"] |
| |
| def test_delete_operator_pod(self, mock_get_connection): |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| is_delete_operator_pod=True, |
| ) |
| context = create_context(k) |
| k.execute(context) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| assert self.expected_pod["spec"] == actual_pod["spec"] |
| assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"] |
| |
| def test_already_checked_on_success(self, mock_get_connection): |
| """ |
| When ``is_delete_operator_pod=False``, pod should have 'already_checked' |
| label, whether pod is successful or not. |
| """ |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| is_delete_operator_pod=False, |
| ) |
| context = create_context(k) |
| k.execute(context) |
| actual_pod = k.find_pod("default", context, exclude_checked=False) |
| actual_pod = self.api_client.sanitize_for_serialization(actual_pod) |
| assert actual_pod["metadata"]["labels"]["already_checked"] == "True" |
| |
| def test_already_checked_on_failure(self, mock_get_connection): |
| """ |
| When ``is_delete_operator_pod=False``, pod should have 'already_checked' |
| label, whether pod is successful or not. |
| """ |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["lalala"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| is_delete_operator_pod=False, |
| ) |
| context = create_context(k) |
| with pytest.raises(AirflowException): |
| k.execute(context) |
| actual_pod = k.find_pod("default", context, exclude_checked=False) |
| actual_pod = self.api_client.sanitize_for_serialization(actual_pod) |
| status = next(iter(filter(lambda x: x["name"] == "base", actual_pod["status"]["containerStatuses"]))) |
| assert status["state"]["terminated"]["reason"] == "Error" |
| assert actual_pod["metadata"]["labels"]["already_checked"] == "True" |
| |
| def test_pod_hostnetwork(self, mock_get_connection): |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| hostnetwork=True, |
| ) |
| context = create_context(k) |
| k.execute(context) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| self.expected_pod["spec"]["hostNetwork"] = True |
| assert self.expected_pod["spec"] == actual_pod["spec"] |
| assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"] |
| |
| def test_pod_dnspolicy(self, mock_get_connection): |
| dns_policy = "ClusterFirstWithHostNet" |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| hostnetwork=True, |
| dnspolicy=dns_policy, |
| ) |
| context = create_context(k) |
| k.execute(context) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| self.expected_pod["spec"]["hostNetwork"] = True |
| self.expected_pod["spec"]["dnsPolicy"] = dns_policy |
| assert self.expected_pod["spec"] == actual_pod["spec"] |
| assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"] |
| |
| def test_pod_schedulername(self, mock_get_connection): |
| scheduler_name = "default-scheduler" |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| schedulername=scheduler_name, |
| ) |
| context = create_context(k) |
| k.execute(context) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| self.expected_pod["spec"]["schedulerName"] = scheduler_name |
| assert self.expected_pod == actual_pod |
| |
| def test_pod_node_selector(self, mock_get_connection): |
| node_selector = {"beta.kubernetes.io/os": "linux"} |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| node_selector=node_selector, |
| ) |
| context = create_context(k) |
| k.execute(context) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| self.expected_pod["spec"]["nodeSelector"] = node_selector |
| assert self.expected_pod == actual_pod |
| |
| def test_pod_resources(self, mock_get_connection): |
| resources = k8s.V1ResourceRequirements( |
| requests={"memory": "64Mi", "cpu": "250m", "ephemeral-storage": "1Gi"}, |
| limits={"memory": "64Mi", "cpu": 0.25, "nvidia.com/gpu": None, "ephemeral-storage": "2Gi"}, |
| ) |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| container_resources=resources, |
| ) |
| context = create_context(k) |
| k.execute(context) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| self.expected_pod["spec"]["containers"][0]["resources"] = { |
| "requests": {"memory": "64Mi", "cpu": "250m", "ephemeral-storage": "1Gi"}, |
| "limits": {"memory": "64Mi", "cpu": 0.25, "nvidia.com/gpu": None, "ephemeral-storage": "2Gi"}, |
| } |
| assert self.expected_pod == actual_pod |
| |
| @pytest.mark.parametrize( |
| "val", |
| [ |
| param( |
| k8s.V1Affinity( |
| node_affinity=k8s.V1NodeAffinity( |
| required_during_scheduling_ignored_during_execution=k8s.V1NodeSelector( |
| node_selector_terms=[ |
| k8s.V1NodeSelectorTerm( |
| match_expressions=[ |
| k8s.V1NodeSelectorRequirement( |
| key="beta.kubernetes.io/os", |
| operator="In", |
| values=["linux"], |
| ) |
| ] |
| ) |
| ] |
| ) |
| ) |
| ), |
| id="current", |
| ), |
| param( |
| { |
| "nodeAffinity": { |
| "requiredDuringSchedulingIgnoredDuringExecution": { |
| "nodeSelectorTerms": [ |
| { |
| "matchExpressions": [ |
| { |
| "key": "beta.kubernetes.io/os", |
| "operator": "In", |
| "values": ["linux"], |
| } |
| ] |
| } |
| ] |
| } |
| } |
| }, |
| id="backcompat", |
| ), |
| ], |
| ) |
| def test_pod_affinity(self, val, mock_get_connection): |
| expected = { |
| "nodeAffinity": { |
| "requiredDuringSchedulingIgnoredDuringExecution": { |
| "nodeSelectorTerms": [ |
| { |
| "matchExpressions": [ |
| {"key": "beta.kubernetes.io/os", "operator": "In", "values": ["linux"]} |
| ] |
| } |
| ] |
| } |
| } |
| } |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| affinity=val, |
| ) |
| context = create_context(k) |
| k.execute(context=context) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| self.expected_pod["spec"]["affinity"] = expected |
| assert self.expected_pod == actual_pod |
| |
| def test_port(self, mock_get_connection): |
| port = k8s.V1ContainerPort( |
| name="http", |
| container_port=80, |
| ) |
| |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| ports=[port], |
| ) |
| context = create_context(k) |
| k.execute(context=context) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| self.expected_pod["spec"]["containers"][0]["ports"] = [{"name": "http", "containerPort": 80}] |
| assert self.expected_pod == actual_pod |
| |
| def test_volume_mount(self, mock_get_connection): |
| with mock.patch.object(PodManager, "log") as mock_logger: |
| volume_mount = k8s.V1VolumeMount( |
| name="test-volume", mount_path="/tmp/test_volume", sub_path=None, read_only=False |
| ) |
| |
| volume = k8s.V1Volume( |
| name="test-volume", |
| persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(claim_name="test-volume"), |
| ) |
| |
| args = [ |
| 'echo "retrieved from mount" > /tmp/test_volume/test.txt && cat /tmp/test_volume/test.txt' |
| ] |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=args, |
| labels=self.labels, |
| volume_mounts=[volume_mount], |
| volumes=[volume], |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| ) |
| context = create_context(k) |
| k.execute(context=context) |
| mock_logger.info.assert_any_call("retrieved from mount") |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| self.expected_pod["spec"]["containers"][0]["args"] = args |
| self.expected_pod["spec"]["containers"][0]["volumeMounts"] = [ |
| {"name": "test-volume", "mountPath": "/tmp/test_volume", "readOnly": False} |
| ] |
| self.expected_pod["spec"]["volumes"] = [ |
| {"name": "test-volume", "persistentVolumeClaim": {"claimName": "test-volume"}} |
| ] |
| assert self.expected_pod == actual_pod |
| |
| @pytest.mark.parametrize("uid", [0, 1000]) |
| def test_run_as_user(self, uid, mock_get_connection): |
| security_context = {"runAsUser": uid} |
| name = str(uuid4()) |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| task_id=name, |
| name=name, |
| random_name_suffix=False, |
| is_delete_operator_pod=False, |
| in_cluster=False, |
| do_xcom_push=False, |
| security_context=security_context, |
| ) |
| context = create_context(k) |
| k.execute(context) |
| pod = k.hook.core_v1_client.read_namespaced_pod( |
| name=name, |
| namespace="default", |
| ) |
| assert pod.to_dict()["spec"]["security_context"]["run_as_user"] == uid |
| |
| @pytest.mark.parametrize("gid", [0, 1000]) |
| def test_fs_group(self, gid, mock_get_connection): |
| security_context = {"fsGroup": gid} |
| name = str(uuid4()) |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| task_id=name, |
| name=name, |
| random_name_suffix=False, |
| is_delete_operator_pod=False, |
| in_cluster=False, |
| do_xcom_push=False, |
| security_context=security_context, |
| ) |
| context = create_context(k) |
| k.execute(context) |
| pod = k.hook.core_v1_client.read_namespaced_pod( |
| name=name, |
| namespace="default", |
| ) |
| assert pod.to_dict()["spec"]["security_context"]["fs_group"] == gid |
| |
| def test_disable_privilege_escalation(self, mock_get_connection): |
| container_security_context = {"allowPrivilegeEscalation": False} |
| |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| container_security_context=container_security_context, |
| ) |
| context = create_context(k) |
| k.execute(context) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| self.expected_pod["spec"]["containers"][0]["securityContext"] = container_security_context |
| assert self.expected_pod == actual_pod |
| |
| def test_faulty_image(self, mock_get_connection): |
| bad_image_name = "foobar" |
| k = KubernetesPodOperator( |
| namespace="default", |
| image=bad_image_name, |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| startup_timeout_seconds=5, |
| ) |
| with pytest.raises(AirflowException): |
| context = create_context(k) |
| k.execute(context) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| self.expected_pod["spec"]["containers"][0]["image"] = bad_image_name |
| assert self.expected_pod == actual_pod |
| |
| def test_faulty_service_account(self, mock_get_connection): |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| startup_timeout_seconds=5, |
| service_account_name="foobar", |
| ) |
| context = create_context(k) |
| pod = k.build_pod_request_obj(context) |
| with pytest.raises(ApiException, match="error looking up service account default/foobar"): |
| k.get_or_create_pod(pod, context) |
| |
| def test_pod_failure(self, mock_get_connection): |
| """ |
| Tests that the task fails when a pod reports a failure |
| """ |
| bad_internal_command = ["foobar 10 "] |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=bad_internal_command, |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| ) |
| with pytest.raises(AirflowException): |
| context = create_context(k) |
| k.execute(context) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| self.expected_pod["spec"]["containers"][0]["args"] = bad_internal_command |
| assert self.expected_pod == actual_pod |
| |
| def test_xcom_push(self, test_label, mock_get_connection): |
| expected = {"test_label": test_label, "buzz": 2} |
| args = [f"echo '{json.dumps(expected)}' > /airflow/xcom/return.json"] |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=args, |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=True, |
| ) |
| context = create_context(k) |
| assert k.execute(context) == expected |
| |
| def test_env_vars(self, mock_get_connection): |
| # WHEN |
| env_vars = [ |
| k8s.V1EnvVar(name="ENV1", value="val1"), |
| k8s.V1EnvVar(name="ENV2", value="val2"), |
| k8s.V1EnvVar( |
| name="ENV3", |
| value_from=k8s.V1EnvVarSource(field_ref=k8s.V1ObjectFieldSelector(field_path="status.podIP")), |
| ), |
| ] |
| |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| env_vars=env_vars, |
| labels=self.labels, |
| task_id=str(uuid4()), |
| in_cluster=False, |
| do_xcom_push=False, |
| ) |
| # THEN |
| context = create_context(k) |
| actual_pod = self.api_client.sanitize_for_serialization(k.build_pod_request_obj(context)) |
| self.expected_pod["spec"]["containers"][0]["env"] = [ |
| {"name": "ENV1", "value": "val1"}, |
| {"name": "ENV2", "value": "val2"}, |
| {"name": "ENV3", "valueFrom": {"fieldRef": {"fieldPath": "status.podIP"}}}, |
| ] |
| assert self.expected_pod == actual_pod |
| |
| def test_pod_template_file_system(self, mock_get_connection): |
| """Note: this test requires that you have a namespace ``mem-example`` in your cluster.""" |
| fixture = sys.path[0] + "/tests/kubernetes/basic_pod.yaml" |
| k = KubernetesPodOperator( |
| task_id=str(uuid4()), |
| in_cluster=False, |
| labels=self.labels, |
| pod_template_file=fixture, |
| do_xcom_push=True, |
| ) |
| |
| context = create_context(k) |
| result = k.execute(context) |
| assert result is not None |
| assert result == {"hello": "world"} |
| |
| @pytest.mark.parametrize( |
| "env_vars", |
| [ |
| param([k8s.V1EnvVar(name="env_name", value="value")], id="current"), |
| param({"env_name": "value"}, id="backcompat"), # todo: remove? |
| ], |
| ) |
| def test_pod_template_file_with_overrides_system(self, env_vars, test_label, mock_get_connection): |
| fixture = sys.path[0] + "/tests/kubernetes/basic_pod.yaml" |
| k = KubernetesPodOperator( |
| task_id=str(uuid4()), |
| labels=self.labels, |
| env_vars=env_vars, |
| in_cluster=False, |
| pod_template_file=fixture, |
| do_xcom_push=True, |
| ) |
| |
| context = create_context(k) |
| result = k.execute(context) |
| assert result is not None |
| assert k.pod.metadata.labels == { |
| "test_label": test_label, |
| "airflow_version": mock.ANY, |
| "airflow_kpo_in_cluster": "False", |
| "dag_id": "dag", |
| "run_id": "manual__2016-01-01T0100000100-da4d1ce7b", |
| "kubernetes_pod_operator": "True", |
| "task_id": mock.ANY, |
| "try_number": "1", |
| } |
| assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")] |
| assert result == {"hello": "world"} |
| |
| def test_pod_template_file_with_full_pod_spec(self, test_label, mock_get_connection): |
| fixture = sys.path[0] + "/tests/kubernetes/basic_pod.yaml" |
| pod_spec = k8s.V1Pod( |
| metadata=k8s.V1ObjectMeta( |
| labels={"test_label": test_label, "fizz": "buzz"}, |
| ), |
| spec=k8s.V1PodSpec( |
| containers=[ |
| k8s.V1Container( |
| name="base", |
| env=[k8s.V1EnvVar(name="env_name", value="value")], |
| ) |
| ] |
| ), |
| ) |
| k = KubernetesPodOperator( |
| task_id=str(uuid4()), |
| labels=self.labels, |
| in_cluster=False, |
| pod_template_file=fixture, |
| full_pod_spec=pod_spec, |
| do_xcom_push=True, |
| ) |
| |
| context = create_context(k) |
| result = k.execute(context) |
| assert result is not None |
| assert k.pod.metadata.labels == { |
| "fizz": "buzz", |
| "test_label": test_label, |
| "airflow_version": mock.ANY, |
| "airflow_kpo_in_cluster": "False", |
| "dag_id": "dag", |
| "run_id": "manual__2016-01-01T0100000100-da4d1ce7b", |
| "kubernetes_pod_operator": "True", |
| "task_id": mock.ANY, |
| "try_number": "1", |
| } |
| assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")] |
| assert result == {"hello": "world"} |
| |
| def test_full_pod_spec(self, test_label, mock_get_connection): |
| pod_spec = k8s.V1Pod( |
| metadata=k8s.V1ObjectMeta( |
| labels={"test_label": test_label, "fizz": "buzz"}, namespace="default", name="test-pod" |
| ), |
| spec=k8s.V1PodSpec( |
| containers=[ |
| k8s.V1Container( |
| name="base", |
| image="perl", |
| command=["/bin/bash"], |
| args=["-c", 'echo {\\"hello\\" : \\"world\\"} | cat > /airflow/xcom/return.json'], |
| env=[k8s.V1EnvVar(name="env_name", value="value")], |
| ) |
| ], |
| restart_policy="Never", |
| ), |
| ) |
| k = KubernetesPodOperator( |
| task_id=str(uuid4()), |
| in_cluster=False, |
| labels=self.labels, |
| full_pod_spec=pod_spec, |
| do_xcom_push=True, |
| is_delete_operator_pod=False, |
| startup_timeout_seconds=30, |
| ) |
| |
| context = create_context(k) |
| result = k.execute(context) |
| assert result is not None |
| assert k.pod.metadata.labels == { |
| "fizz": "buzz", |
| "test_label": test_label, |
| "airflow_version": mock.ANY, |
| "airflow_kpo_in_cluster": "False", |
| "dag_id": "dag", |
| "run_id": "manual__2016-01-01T0100000100-da4d1ce7b", |
| "kubernetes_pod_operator": "True", |
| "task_id": mock.ANY, |
| "try_number": "1", |
| } |
| assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")] |
| assert result == {"hello": "world"} |
| |
| def test_init_container(self, mock_get_connection): |
| # GIVEN |
| volume_mounts = [ |
| k8s.V1VolumeMount(mount_path="/etc/foo", name="test-volume", sub_path=None, read_only=True) |
| ] |
| |
| init_environments = [ |
| k8s.V1EnvVar(name="key1", value="value1"), |
| k8s.V1EnvVar(name="key2", value="value2"), |
| ] |
| |
| init_container = k8s.V1Container( |
| name="init-container", |
| image="ubuntu:16.04", |
| env=init_environments, |
| volume_mounts=volume_mounts, |
| command=["bash", "-cx"], |
| args=["echo 10"], |
| ) |
| |
| volume = k8s.V1Volume( |
| name="test-volume", |
| persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(claim_name="test-volume"), |
| ) |
| expected_init_container = { |
| "name": "init-container", |
| "image": "ubuntu:16.04", |
| "command": ["bash", "-cx"], |
| "args": ["echo 10"], |
| "env": [{"name": "key1", "value": "value1"}, {"name": "key2", "value": "value2"}], |
| "volumeMounts": [{"mountPath": "/etc/foo", "name": "test-volume", "readOnly": True}], |
| } |
| |
| k = KubernetesPodOperator( |
| namespace="default", |
| image="ubuntu:16.04", |
| cmds=["bash", "-cx"], |
| arguments=["echo 10"], |
| labels=self.labels, |
| task_id=str(uuid4()), |
| volumes=[volume], |
| init_containers=[init_container], |
| in_cluster=False, |
| do_xcom_push=False, |
| ) |
| context = create_context(k) |
| k.execute(context) |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| self.expected_pod["spec"]["initContainers"] = [expected_init_container] |
| self.expected_pod["spec"]["volumes"] = [ |
| {"name": "test-volume", "persistentVolumeClaim": {"claimName": "test-volume"}} |
| ] |
| assert self.expected_pod == actual_pod |
| |
| @mock.patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start") |
| @mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom") |
| @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion") |
| @mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock) |
| @mock.patch(HOOK_CLASS) |
| def test_pod_template_file( |
| self, |
| hook_mock, |
| await_pod_completion_mock, |
| extract_xcom_mock, |
| await_xcom_sidecar_container_start_mock, |
| caplog, |
| test_label, |
| ): |
| # todo: This isn't really a system test |
| await_xcom_sidecar_container_start_mock.return_value = None |
| hook_mock.return_value.is_in_cluster = False |
| hook_mock.return_value.get_connection.return_value = Connection(conn_id="kubernetes_default") |
| extract_xcom_mock.return_value = "{}" |
| path = sys.path[0] + "/tests/kubernetes/pod.yaml" |
| k = KubernetesPodOperator( |
| task_id=str(uuid4()), |
| labels=self.labels, |
| random_name_suffix=False, |
| pod_template_file=path, |
| do_xcom_push=True, |
| ) |
| pod_mock = MagicMock() |
| pod_mock.status.phase = "Succeeded" |
| await_pod_completion_mock.return_value = pod_mock |
| context = create_context(k) |
| |
| # I'm not really sure what the point is of this assert |
| with caplog.at_level(logging.DEBUG, logger="airflow.task.operators"): |
| k.execute(context) |
| expected_lines = [ |
| "Starting pod:", |
| "api_version: v1", |
| "kind: Pod", |
| "metadata:", |
| " annotations: {}", |
| " cluster_name: null", |
| " creation_timestamp: null", |
| " deletion_grace_period_seconds: null", |
| ] |
| actual = [x.getMessage() for x in caplog.records if x.msg == "Starting pod:\n%s"][0].splitlines() |
| assert actual[: len(expected_lines)] == expected_lines |
| |
| actual_pod = self.api_client.sanitize_for_serialization(k.pod) |
| expected_dict = { |
| "apiVersion": "v1", |
| "kind": "Pod", |
| "metadata": { |
| "annotations": {}, |
| "labels": { |
| "test_label": test_label, |
| "airflow_kpo_in_cluster": "False", |
| "dag_id": "dag", |
| "run_id": "manual__2016-01-01T0100000100-da4d1ce7b", |
| "kubernetes_pod_operator": "True", |
| "task_id": mock.ANY, |
| "try_number": "1", |
| }, |
| "name": "memory-demo", |
| "namespace": "mem-example", |
| }, |
| "spec": { |
| "affinity": {}, |
| "containers": [ |
| { |
| "args": ["--vm", "1", "--vm-bytes", "150M", "--vm-hang", "1"], |
| "command": ["stress"], |
| "env": [], |
| "envFrom": [], |
| "image": "ghcr.io/apache/airflow-stress:1.0.4-2021.07.04", |
| "name": "base", |
| "ports": [], |
| "resources": {"limits": {"memory": "200Mi"}, "requests": {"memory": "100Mi"}}, |
| "volumeMounts": [{"mountPath": "/airflow/xcom", "name": "xcom"}], |
| }, |
| { |
| "command": ["sh", "-c", 'trap "exit 0" INT; while true; do sleep 1; done;'], |
| "image": "alpine", |
| "name": "airflow-xcom-sidecar", |
| "resources": { |
| "requests": {"cpu": "1m", "memory": "10Mi"}, |
| }, |
| "volumeMounts": [{"mountPath": "/airflow/xcom", "name": "xcom"}], |
| }, |
| ], |
| "hostNetwork": False, |
| "imagePullSecrets": [], |
| "initContainers": [], |
| "nodeSelector": {}, |
| "restartPolicy": "Never", |
| "securityContext": {}, |
| "tolerations": [], |
| "volumes": [{"emptyDir": {}, "name": "xcom"}], |
| }, |
| } |
| version = actual_pod["metadata"]["labels"]["airflow_version"] |
| assert version.startswith(airflow_version) |
| del actual_pod["metadata"]["labels"]["airflow_version"] |
| assert expected_dict == actual_pod |
| |
@mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
@mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
@mock.patch(HOOK_CLASS)
def test_pod_priority_class_name(self, hook_mock, await_pod_completion_mock):
    """
    Check that ``priority_class_name`` is rendered as ``priorityClassName`` in the pod spec.

    The hook class, ``create_pod`` and ``await_pod_completion`` are all patched here,
    and the pod built by the operator is compared with ``self.expected_pod``.

    todo: This isn't really a system test
    """
    patched_hook = hook_mock.return_value
    patched_hook.is_in_cluster = False
    patched_hook.get_connection.return_value = Connection(conn_id="kubernetes_default")

    # the patched ``await_pod_completion`` reports a pod that finished successfully
    finished_pod = MagicMock()
    finished_pod.status.phase = "Succeeded"
    await_pod_completion_mock.return_value = finished_pod

    priority_class_name = "medium-test"
    operator = KubernetesPodOperator(
        namespace="default",
        image="ubuntu:16.04",
        cmds=["bash", "-cx"],
        arguments=["echo 10"],
        labels=self.labels,
        task_id=str(uuid4()),
        in_cluster=False,
        do_xcom_push=False,
        priority_class_name=priority_class_name,
    )
    operator.execute(create_context(operator))

    rendered_pod = self.api_client.sanitize_for_serialization(operator.pod)
    self.expected_pod["spec"]["priorityClassName"] = priority_class_name
    assert self.expected_pod == rendered_pod
| |
def test_pod_name(self, mock_get_connection):
    """Building the operator with a 221-character pod name must raise ``AirflowException``."""
    operator_kwargs = dict(
        namespace="default",
        image="ubuntu:16.04",
        cmds=["bash", "-cx"],
        arguments=["echo 10"],
        labels=self.labels,
        name="a" * 221,
        task_id=str(uuid4()),
        in_cluster=False,
        do_xcom_push=False,
    )
    with pytest.raises(AirflowException):
        KubernetesPodOperator(**operator_kwargs)
| |
def test_on_kill(self, mock_get_connection):
    """
    Check that ``on_kill`` deletes a pod that is still running.

    ``execute`` is aborted by making ``await_container_completion`` raise, so the
    ``sleep 1000`` pod is left running; after ``on_kill`` a read of the pod must
    fail with a "not found" ``ApiException``.
    """
    namespace = "default"
    core_client = KubernetesHook(conn_id=None, in_cluster=False).core_v1_client
    operator = KubernetesPodOperator(
        namespace="default",
        image="ubuntu:16.04",
        cmds=["bash", "-cx"],
        arguments=["sleep 1000"],
        labels=self.labels,
        name="test",
        task_id="test",
        in_cluster=False,
        do_xcom_push=False,
        get_logs=False,
        termination_grace_period=0,
    )
    context = create_context(operator)

    class ShortCircuitException(Exception):
        pass

    # raising from this call short circuits execute() so we do NOT wait for container completion
    skip_waiting = mock.patch.object(
        operator.pod_manager, "await_container_completion", side_effect=ShortCircuitException()
    )
    # cleanup will be upset since the pod should not be completed.. so it is stubbed out
    skip_cleanup = mock.patch.object(operator, "cleanup")
    with skip_waiting, skip_cleanup, pytest.raises(ShortCircuitException):
        operator.execute(context)

    # at this point the pod should still be running
    pod_name = operator.pod.metadata.name
    running_pod = core_client.read_namespaced_pod(name=pod_name, namespace=namespace)
    assert running_pod.status.phase == "Running"

    operator.on_kill()
    with pytest.raises(ApiException, match=r'pods \\"test.[a-z0-9]+\\" not found'):
        core_client.read_namespaced_pod(name=pod_name, namespace=namespace)
| |
def test_reattach_failing_pod_once(self, mock_get_connection):
    """
    Check that a failed pod left behind by a previous run is reattached to exactly once.

    Three operator instances are executed against the same context:

    1. The first launches the pod with ``cleanup`` stubbed out, so the failed pod
       is never labelled ``already_checked``.
    2. The second must not call ``create_pod``; it fails on the existing pod, which
       is then labelled ``already_checked``.
    3. The third must call ``create_pod`` once, because the remaining pod is
       already labelled ``already_checked``.
    """
    import time

    hook = KubernetesHook(conn_id=None, in_cluster=False)
    client = hook.core_v1_client
    name = "test"
    namespace = "default"

    def get_op():
        # ``name`` is never rebound below, so every operator built here gets the same
        # pod name prefix and the same task_id.
        return KubernetesPodOperator(
            namespace="default",
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["exit 1"],
            labels=self.labels,
            name="test",
            task_id=name,
            in_cluster=False,
            do_xcom_push=False,
            is_delete_operator_pod=False,
            termination_grace_period=0,
        )

    k = get_op()

    context = create_context(k)

    # launch pod
    with mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion") as await_pod_completion_mock:
        pod_mock = MagicMock()

        pod_mock.status.phase = "Succeeded"
        await_pod_completion_mock.return_value = pod_mock

        # we want to simulate that there was a worker failure and the airflow operator process
        # was killed without running the cleanup process. in this case the pod will not be marked as
        # already checked
        k.cleanup = MagicMock()

        k.execute(context)
        # keep the generated pod name in its own variable; rebinding ``name`` would leak
        # into ``get_op`` and change the task_id of the operators created later
        pod_name = k.pod.metadata.name

        # poll until the pod reports "Failed", but give up after a deadline instead of
        # spinning on the API server forever
        deadline = time.monotonic() + 120
        pod = client.read_namespaced_pod(name=pod_name, namespace=namespace)
        while pod.status.phase != "Failed":
            if time.monotonic() > deadline:
                pytest.fail(
                    f"Pod {pod_name} did not reach phase 'Failed' in time "
                    f"(last phase: {pod.status.phase})"
                )
            time.sleep(1)
            pod = client.read_namespaced_pod(name=pod_name, namespace=namespace)
        assert "already_checked" not in pod.metadata.labels

    # create a new version of the same operator instance to remove the monkey patching in first
    # part of the test
    k = get_op()

    # `create_pod` should not be called because there's a pod there it should find
    # should use the found pod and patch as "already_checked" (in failure block)
    with mock.patch(f"{POD_MANAGER_CLASS}.create_pod") as create_mock:
        with pytest.raises(AirflowException):
            k.execute(context)
        pod = client.read_namespaced_pod(name=pod_name, namespace=namespace)
        assert pod.metadata.labels["already_checked"] == "True"
        create_mock.assert_not_called()

    # recreate op just to ensure we're not relying on any statefulness
    k = get_op()

    # `create_pod` should be called because though there's still a pod to be found,
    # it will be `already_checked`
    with mock.patch(f"{POD_MANAGER_CLASS}.create_pod") as create_mock:
        with pytest.raises(AirflowException):
            k.execute(context)
        create_mock.assert_called_once()
| |
def test_using_resources(self, mock_get_connection):
    """
    Passing the deprecated ``resources`` argument must raise ``AirflowException``.

    The expected message is matched literally: ``pytest.raises(match=...)`` treats
    its argument as a regular expression, so the text is escaped first.
    """
    import re

    exception_message = (
        "Specifying resources for the launched pod with 'resources' is deprecated. "
        "Use 'container_resources' instead."
    )
    # built outside the ``raises`` block so that only the operator constructor
    # is allowed to raise the expected exception
    resources = k8s.V1ResourceRequirements(
        requests={"memory": "64Mi", "cpu": "250m", "ephemeral-storage": "1Gi"},
        limits={"memory": "64Mi", "cpu": 0.25, "nvidia.com/gpu": None, "ephemeral-storage": "2Gi"},
    )
    with pytest.raises(AirflowException, match=re.escape(exception_message)):
        KubernetesPodOperator(
            namespace="default",
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels=self.labels,
            task_id=str(uuid4()),
            in_cluster=False,
            do_xcom_push=False,
            resources=resources,
        )
| |
def test_changing_base_container_name_with_get_logs(self, mock_get_connection):
    """
    With ``get_logs=True``, a custom ``base_container_name`` must be used both when
    fetching container logs and as the name of the first container in the pod spec.
    """
    custom_name = "apple-sauce"
    operator = KubernetesPodOperator(
        namespace="default",
        image="ubuntu:16.04",
        cmds=["bash", "-cx"],
        arguments=["echo 10"],
        labels=self.labels,
        task_id=str(uuid4()),
        in_cluster=False,
        do_xcom_push=False,
        get_logs=True,
        base_container_name="apple-sauce",
    )
    assert operator.base_container_name == "apple-sauce"
    context = create_context(operator)

    # wrap (rather than replace) the real method so the call still runs and is recorded
    pod_manager = operator.pod_manager
    spy = mock.patch.object(pod_manager, "fetch_container_logs", wraps=pod_manager.fetch_container_logs)
    with spy as mock_fetch_container_logs:
        operator.execute(context)

    _, fetch_kwargs = mock_fetch_container_logs.call_args
    assert fetch_kwargs["container_name"] == "apple-sauce"

    rendered_pod = self.api_client.sanitize_for_serialization(operator.pod)
    self.expected_pod["spec"]["containers"][0]["name"] = custom_name
    assert self.expected_pod["spec"] == rendered_pod["spec"]
| |
def test_changing_base_container_name_no_logs(self, mock_get_connection):
    """
    Cover a custom base container name together with the ``get_logs=False`` flow.

    Because the container only runs ``echo 10``, this also exercises the flow
    with a fast-finishing container.
    See https://github.com/apache/airflow/issues/26796
    """
    custom_name = "apple-sauce"
    operator = KubernetesPodOperator(
        namespace="default",
        image="ubuntu:16.04",
        cmds=["bash", "-cx"],
        arguments=["echo 10"],
        labels=self.labels,
        task_id=str(uuid4()),
        in_cluster=False,
        do_xcom_push=False,
        get_logs=False,
        base_container_name="apple-sauce",
    )
    assert operator.base_container_name == "apple-sauce"
    context = create_context(operator)

    # wrap (rather than replace) the real method so the call still runs and is recorded
    pod_manager = operator.pod_manager
    spy = mock.patch.object(
        pod_manager, "await_container_completion", wraps=pod_manager.await_container_completion
    )
    with spy as mock_await_container_completion:
        operator.execute(context)

    _, await_kwargs = mock_await_container_completion.call_args
    assert await_kwargs["container_name"] == "apple-sauce"

    rendered_pod = self.api_client.sanitize_for_serialization(operator.pod)
    self.expected_pod["spec"]["containers"][0]["name"] = custom_name
    assert self.expected_pod["spec"] == rendered_pod["spec"]
| |
def test_changing_base_container_name_no_logs_long(self, mock_get_connection):
    """
    Variant of test_changing_base_container_name_no_logs whose container runs
    ``sleep 3``, so that pods running longer than 1 second are covered as well.
    See https://github.com/apache/airflow/issues/26796
    """
    custom_name = "apple-sauce"
    operator = KubernetesPodOperator(
        namespace="default",
        image="ubuntu:16.04",
        cmds=["bash", "-cx"],
        arguments=["sleep 3"],
        labels=self.labels,
        task_id=str(uuid4()),
        in_cluster=False,
        do_xcom_push=False,
        get_logs=False,
        base_container_name="apple-sauce",
    )
    assert operator.base_container_name == "apple-sauce"
    context = create_context(operator)

    # wrap (rather than replace) the real method so the call still runs and is recorded
    pod_manager = operator.pod_manager
    spy = mock.patch.object(
        pod_manager, "await_container_completion", wraps=pod_manager.await_container_completion
    )
    with spy as mock_await_container_completion:
        operator.execute(context)

    _, await_kwargs = mock_await_container_completion.call_args
    assert await_kwargs["container_name"] == "apple-sauce"

    rendered_pod = self.api_client.sanitize_for_serialization(operator.pod)
    base_container = self.expected_pod["spec"]["containers"][0]
    base_container["name"] = custom_name
    base_container["args"] = ["sleep 3"]
    assert self.expected_pod["spec"] == rendered_pod["spec"]
| |
def test_changing_base_container_name_failure(self, mock_get_connection):
    """
    When ``get_container_termination_message`` is consulted during ``execute``,
    its second positional argument must be the custom base container name.
    """
    operator = KubernetesPodOperator(
        namespace="default",
        image="ubuntu:16.04",
        cmds=["exit"],
        arguments=["1"],
        labels=self.labels,
        task_id=str(uuid4()),
        in_cluster=False,
        do_xcom_push=False,
        base_container_name="apple-sauce",
    )
    assert operator.base_container_name == "apple-sauce"
    context = create_context(operator)

    class ShortCircuitException(Exception):
        pass

    # the patched helper raises, which ends execute() as soon as it is called
    termination_patch = mock.patch(
        "airflow.providers.cncf.kubernetes.operators.pod.get_container_termination_message",
        side_effect=ShortCircuitException(),
    )
    with termination_patch as mock_get_container_termination_message, pytest.raises(ShortCircuitException):
        operator.execute(context)

    positional_args, _ = mock_get_container_termination_message.call_args
    assert positional_args[1] == "apple-sauce"
| |
def test_base_container_name_init_precedence(self, mock_get_connection):
    """
    An explicit ``base_container_name`` argument wins over the class-level
    ``BASE_CONTAINER_NAME``, which is the fallback when the argument is omitted;
    the same holds for a subclass overriding ``BASE_CONTAINER_NAME``.
    """
    explicit = KubernetesPodOperator(base_container_name="apple-sauce", task_id=str(uuid4()))
    assert explicit.base_container_name == "apple-sauce"

    defaulted = KubernetesPodOperator(task_id=str(uuid4()))
    assert defaulted.base_container_name == KubernetesPodOperator.BASE_CONTAINER_NAME

    class MyK8SPodOperator(KubernetesPodOperator):
        BASE_CONTAINER_NAME = "tomato-sauce"

    explicit_on_subclass = MyK8SPodOperator(base_container_name="apple-sauce", task_id=str(uuid4()))
    assert explicit_on_subclass.base_container_name == "apple-sauce"

    defaulted_on_subclass = MyK8SPodOperator(task_id=str(uuid4()))
    assert defaulted_on_subclass.base_container_name == "tomato-sauce"
| |
| |
def test_hide_sensitive_field_in_templated_fields_on_error(caplog, monkeypatch):
    """
    A failure while rendering templated fields must not log the rendered secret.

    Rendering fails with ``KeyError`` because the ``var`` object in the context
    raises on any attribute access. The captured log must then contain the word
    ``password`` but not the value ``secretpassword``.
    """
    # let records from the "airflow.task" logger propagate so ``caplog`` captures them
    logger = logging.getLogger("airflow.task")
    monkeypatch.setattr(logger, "propagate", True)

    class Var:
        def __getattr__(self, name):
            # any attribute lookup in a template (e.g. ``var.value``) fails
            raise KeyError(name)

    context = {
        "password": "secretpassword",
        "var": Var(),
    }

    # KubernetesPodOperator is already imported at module level, so it is not re-imported here
    task = KubernetesPodOperator(
        task_id="dry_run_demo",
        name="hello-dry-run",
        image="python:3.8-slim-buster",
        cmds=["printenv"],
        env_vars={
            "password": "{{ password }}",
            "VAR2": "{{ var.value.nonexisting}}",
        },
    )
    with pytest.raises(KeyError):
        task.render_template_fields(context=context)
    assert "password" in caplog.text
    assert "secretpassword" not in caplog.text