| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| import json |
| import os |
| import tempfile |
| from asyncio import Future |
| from unittest import mock |
| from unittest.mock import MagicMock, patch |
| |
| import kubernetes |
| import pytest |
| from kubernetes.client.rest import ApiException |
| from kubernetes.config import ConfigException |
| from sqlalchemy.orm import make_transient |
| |
| from airflow.exceptions import AirflowException, AirflowNotFoundException |
| from airflow.hooks.base import BaseHook |
| from airflow.models import Connection |
| from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook, KubernetesHook |
| from airflow.utils import db |
| from airflow.utils.db import merge_conn |
| from tests.test_utils.db import clear_db_connections |
| from tests.test_utils.providers import get_provider_min_airflow_version |
| |
# Every test here reads or writes Connection rows, so the whole module is a DB test.
pytestmark = pytest.mark.db_test


# Dotted path of the module under test; used to build ``patch`` targets.
HOOK_MODULE = "airflow.providers.cncf.kubernetes.hooks.kubernetes"
# Default kubeconfig location, overridable through the KUBECONFIG environment variable.
KUBE_CONFIG_PATH = os.getenv("KUBECONFIG", "~/.kube/config")

# Shared identifiers used by the async hook and job tests.
CONN_ID = "kubernetes-test-id"
ASYNC_CONFIG_PATH = "/files/path/to/config/file"
POD_NAME = "test-pod"
NAMESPACE = "test-namespace"
JOB_NAME = "test-job"
POLL_INTERVAL = 100
| |
| |
class DeprecationRemovalRequired(AirflowException):
    """Raised by a test to remind maintainers that a deprecation shim is now due for removal."""
| |
| |
# Connection id the tests treat as the hook's default (used when no conn_id is passed).
DEFAULT_CONN_ID: str = "kubernetes_default"
| |
| |
@pytest.fixture
def remove_default_conn(session):
    """Temporarily delete the default Kubernetes connection for the duration of a test.

    If the connection exists it is deleted before the test and re-inserted afterwards;
    if it does not exist, the fixture does nothing.
    """
    saved_conn = (
        session.query(Connection)
        .filter(Connection.conn_id == DEFAULT_CONN_ID)
        .one_or_none()
    )
    if not saved_conn:
        yield
        return

    session.delete(saved_conn)
    session.commit()

    yield

    # The deleted instance must be made transient before it can be added back.
    make_transient(saved_conn)
    session.add(saved_conn)
    session.commit()
| |
| |
class TestKubernetesHook:
    """Tests for the synchronous ``KubernetesHook``."""

    @classmethod
    def setup_class(cls) -> None:
        # One connection per configuration variant exercised by the parametrized tests below.
        for conn_id, extra in [
            ("in_cluster", {"in_cluster": True}),
            ("in_cluster_empty", {"in_cluster": ""}),
            ("kube_config", {"kube_config": '{"test": "kube"}'}),
            ("kube_config_path", {"kube_config_path": "path/to/file"}),
            ("kube_config_empty", {"kube_config": ""}),
            ("kube_config_path_empty", {"kube_config_path": ""}),
            ("context_empty", {"cluster_context": ""}),
            ("context", {"cluster_context": "my-context"}),
            ("with_namespace", {"namespace": "mock_namespace"}),
            ("default_kube_config", {}),
            ("disable_verify_ssl", {"disable_verify_ssl": True}),
            ("disable_verify_ssl_empty", {"disable_verify_ssl": ""}),
            ("disable_tcp_keepalive", {"disable_tcp_keepalive": True}),
            ("disable_tcp_keepalive_empty", {"disable_tcp_keepalive": ""}),
            ("sidecar_container_image", {"xcom_sidecar_container_image": "private.repo.com/alpine:3.16"}),
            ("sidecar_container_image_empty", {"xcom_sidecar_container_image": ""}),
            (
                "sidecar_container_resources",
                {
                    "xcom_sidecar_container_resources": json.dumps(
                        {
                            "requests": {"cpu": "1m", "memory": "10Mi"},
                            "limits": {"cpu": "1m", "memory": "50Mi"},
                        }
                    ),
                },
            ),
            ("sidecar_container_resources_empty", {"xcom_sidecar_container_resources": ""}),
        ]:
            db.merge_conn(Connection(conn_type="kubernetes", conn_id=conn_id, extra=json.dumps(extra)))

    @classmethod
    def teardown_class(cls) -> None:
        clear_db_connections()

    @pytest.mark.parametrize(
        "in_cluster_param, conn_id, in_cluster_called",
        (
            (True, None, True),
            (None, None, False),
            (False, None, False),
            (None, "in_cluster", True),
            (True, "in_cluster", True),
            (False, "in_cluster", False),
            (None, "in_cluster_empty", False),
            (True, "in_cluster_empty", True),
            (False, "in_cluster_empty", False),
        ),
    )
    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    @patch("kubernetes.config.incluster_config.InClusterConfigLoader")
    @patch(f"{HOOK_MODULE}.KubernetesHook._get_default_client")
    def test_in_cluster_connection(
        self,
        mock_get_default_client,
        mock_in_cluster_loader,
        mock_merger,
        mock_loader,
        in_cluster_param,
        conn_id,
        in_cluster_called,
    ):
        """
        Verifies whether in_cluster is called depending on combination of hook param and connection extra.
        Hook param should beat extra.
        """
        kubernetes_hook = KubernetesHook(conn_id=conn_id, in_cluster=in_cluster_param)
        mock_get_default_client.return_value = kubernetes.client.api_client.ApiClient()
        api_conn = kubernetes_hook.get_conn()
        if in_cluster_called:
            mock_in_cluster_loader.assert_called_once()
            mock_merger.assert_not_called()
            mock_loader.assert_not_called()
        else:
            mock_get_default_client.assert_called()
        assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
        if not mock_get_default_client.called:
            # get_default_client is mocked, so only check is_in_cluster if it isn't called
            assert kubernetes_hook.is_in_cluster is in_cluster_called

    @pytest.mark.parametrize("in_cluster_fails", [True, False])
    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    @patch("kubernetes.config.incluster_config.InClusterConfigLoader")
    def test_get_default_client(
        self,
        mock_incluster,
        mock_merger,
        mock_loader,
        in_cluster_fails,
    ):
        """
        Verifies the behavior of the ``_get_default_client`` function. It should try the "in cluster"
        loader first but if that fails, try to use the default kubeconfig file.
        """
        if in_cluster_fails:
            mock_incluster.side_effect = ConfigException("any")
        kubernetes_hook = KubernetesHook()
        api_conn = kubernetes_hook._get_default_client()
        if in_cluster_fails:
            mock_incluster.assert_called_once()
            mock_merger.assert_called_once_with(KUBE_CONFIG_PATH)
            mock_loader.assert_called_once()
            assert kubernetes_hook.is_in_cluster is False
        else:
            mock_incluster.assert_called_once()
            mock_merger.assert_not_called()
            mock_loader.assert_not_called()
            assert kubernetes_hook.is_in_cluster is True
        assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

    @pytest.mark.parametrize(
        "disable_verify_ssl, conn_id, disable_called",
        (
            (True, None, True),
            (None, None, False),
            (False, None, False),
            (None, "disable_verify_ssl", True),
            (True, "disable_verify_ssl", True),
            (False, "disable_verify_ssl", False),
            (None, "disable_verify_ssl_empty", False),
            (True, "disable_verify_ssl_empty", True),
            (False, "disable_verify_ssl_empty", False),
        ),
    )
    @patch("kubernetes.config.incluster_config.InClusterConfigLoader", new=MagicMock())
    @patch(f"{HOOK_MODULE}._disable_verify_ssl")
    def test_disable_verify_ssl(
        self,
        mock_disable,
        disable_verify_ssl,
        conn_id,
        disable_called,
    ):
        """
        Verifies whether disable verify ssl is called depending on combination of hook param and
        connection extra. Hook param should beat extra.
        """
        kubernetes_hook = KubernetesHook(conn_id=conn_id, disable_verify_ssl=disable_verify_ssl)
        api_conn = kubernetes_hook.get_conn()
        assert mock_disable.called is disable_called
        assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

    @pytest.mark.parametrize(
        "disable_tcp_keepalive, conn_id, expected",
        (
            (True, None, False),
            (None, None, True),
            (False, None, True),
            (None, "disable_tcp_keepalive", False),
            (True, "disable_tcp_keepalive", False),
            (False, "disable_tcp_keepalive", True),
            (None, "disable_tcp_keepalive_empty", True),
            (True, "disable_tcp_keepalive_empty", False),
            (False, "disable_tcp_keepalive_empty", True),
        ),
    )
    @patch("kubernetes.config.incluster_config.InClusterConfigLoader", new=MagicMock())
    @patch(f"{HOOK_MODULE}._enable_tcp_keepalive")
    def test_disable_tcp_keepalive(
        self,
        mock_enable,
        disable_tcp_keepalive,
        conn_id,
        expected,
    ):
        """
        Verifies whether enable tcp keepalive is called depending on combination of hook
        param and connection extra. Hook param should beat extra.
        """
        kubernetes_hook = KubernetesHook(conn_id=conn_id, disable_tcp_keepalive=disable_tcp_keepalive)
        api_conn = kubernetes_hook.get_conn()
        assert mock_enable.called is expected
        assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

    @pytest.mark.parametrize(
        "config_path_param, conn_id, call_path",
        (
            (None, None, KUBE_CONFIG_PATH),
            ("/my/path/override", None, "/my/path/override"),
            (None, "kube_config_path", "path/to/file"),
            ("/my/path/override", "kube_config_path", "/my/path/override"),
            (None, "kube_config_path_empty", KUBE_CONFIG_PATH),
            ("/my/path/override", "kube_config_path_empty", "/my/path/override"),
        ),
    )
    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    def test_kube_config_path(
        self, mock_kube_config_merger, mock_kube_config_loader, config_path_param, conn_id, call_path
    ):
        """
        Verifies kube config path depending on combination of hook param and connection extra.
        Hook param should beat extra.
        """
        kubernetes_hook = KubernetesHook(conn_id=conn_id, config_file=config_path_param)
        api_conn = kubernetes_hook.get_conn()
        mock_kube_config_merger.assert_called_once_with(call_path)
        mock_kube_config_loader.assert_called_once()
        assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

    @pytest.mark.parametrize(
        "conn_id, has_config",
        (
            (None, False),
            ("kube_config", True),
            ("kube_config_empty", False),
        ),
    )
    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    @patch.object(tempfile, "NamedTemporaryFile")
    def test_kube_config_connection(
        self, mock_tempfile, mock_kube_config_merger, mock_kube_config_loader, conn_id, has_config
    ):
        """
        Verifies whether temporary kube config file is created.
        """
        mock_tempfile.return_value.__enter__.return_value.name = "fake-temp-file"
        mock_kube_config_merger.return_value.config = {"fake_config": "value"}
        kubernetes_hook = KubernetesHook(conn_id=conn_id)
        api_conn = kubernetes_hook.get_conn()
        if has_config:
            # ``is_called_once`` is not an assertion on a MagicMock (it silently returns a
            # child mock); ``assert_called_once`` actually verifies the call count.
            mock_tempfile.assert_called_once()
            mock_kube_config_loader.assert_called_once()
            mock_kube_config_merger.assert_called_once_with("fake-temp-file")
        else:
            mock_tempfile.assert_not_called()
            mock_kube_config_loader.assert_called_once()
            mock_kube_config_merger.assert_called_once_with(KUBE_CONFIG_PATH)
        assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

    @pytest.mark.parametrize(
        "context_param, conn_id, expected_context",
        (
            ("param-context", None, "param-context"),
            (None, None, None),
            ("param-context", "context", "param-context"),
            (None, "context", "my-context"),
            ("param-context", "context_empty", "param-context"),
            (None, "context_empty", None),
        ),
    )
    @patch("kubernetes.config.load_kube_config")
    def test_cluster_context(self, mock_load_kube_config, context_param, conn_id, expected_context):
        """
        Verifies cluster context depending on combination of hook param and connection extra.
        Hook param should beat extra.
        """
        kubernetes_hook = KubernetesHook(conn_id=conn_id, cluster_context=context_param)
        kubernetes_hook.get_conn()
        mock_load_kube_config.assert_called_with(client_configuration=None, context=expected_context)

    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    @patch("kubernetes.config.kube_config.KUBE_CONFIG_DEFAULT_LOCATION", "/mock/config")
    def test_default_kube_config_connection(self, mock_kube_config_merger, mock_kube_config_loader):
        kubernetes_hook = KubernetesHook(conn_id="default_kube_config")
        api_conn = kubernetes_hook.get_conn()
        mock_kube_config_merger.assert_called_once_with("/mock/config")
        mock_kube_config_loader.assert_called_once()
        assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

    @pytest.mark.parametrize(
        "conn_id, expected",
        (
            pytest.param(None, None, id="no-conn-id"),
            pytest.param("with_namespace", "mock_namespace", id="conn-with-namespace"),
            pytest.param("default_kube_config", None, id="conn-without-namespace"),
        ),
    )
    def test_get_namespace(self, conn_id, expected):
        hook = KubernetesHook(conn_id=conn_id)
        assert hook.get_namespace() == expected
        # NOTE(review): this compares the provider's *minimum Airflow version* with (6, 0);
        # confirm that is the intended trigger for the deprecation removal.
        if get_provider_min_airflow_version("apache-airflow-providers-cncf-kubernetes") >= (6, 0):
            raise DeprecationRemovalRequired(
                "You must update get_namespace so that if namespace not set "
                "in the connection, then None is returned. To do so, remove get_namespace "
                "and rename _get_namespace to get_namespace."
            )

    @pytest.mark.parametrize(
        "conn_id, expected",
        (
            pytest.param("sidecar_container_image", "private.repo.com/alpine:3.16", id="sidecar-with-image"),
            pytest.param("sidecar_container_image_empty", None, id="sidecar-without-image"),
        ),
    )
    def test_get_xcom_sidecar_container_image(self, conn_id, expected):
        hook = KubernetesHook(conn_id=conn_id)
        assert hook.get_xcom_sidecar_container_image() == expected

    @pytest.mark.parametrize(
        "conn_id, expected",
        (
            pytest.param(
                "sidecar_container_resources",
                {
                    "requests": {"cpu": "1m", "memory": "10Mi"},
                    "limits": {
                        "cpu": "1m",
                        "memory": "50Mi",
                    },
                },
                id="sidecar-with-resources",
            ),
            pytest.param("sidecar_container_resources_empty", None, id="sidecar-without-resources"),
        ),
    )
    def test_get_xcom_sidecar_container_resources(self, conn_id, expected):
        hook = KubernetesHook(conn_id=conn_id)
        assert hook.get_xcom_sidecar_container_resources() == expected

    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    def test_client_types(self, mock_kube_config_merger, mock_kube_config_loader):
        hook = KubernetesHook(None)
        assert isinstance(hook.core_v1_client, kubernetes.client.CoreV1Api)
        assert isinstance(hook.api_client, kubernetes.client.ApiClient)
        assert isinstance(hook.get_conn(), kubernetes.client.ApiClient)

    @patch(f"{HOOK_MODULE}.KubernetesHook._get_default_client")
    def test_prefixed_names_still_work(self, mock_get_client):
        conn_uri = "kubernetes://?extra__kubernetes__cluster_context=test&extra__kubernetes__namespace=test"
        with mock.patch.dict("os.environ", AIRFLOW_CONN_KUBERNETES_DEFAULT=conn_uri):
            kubernetes_hook = KubernetesHook(conn_id="kubernetes_default")
            kubernetes_hook.get_conn()
            mock_get_client.assert_called_with(cluster_context="test")
            assert kubernetes_hook.get_namespace() == "test"

    def test_missing_default_connection_is_ok(self, remove_default_conn):
        # prove to ourselves that the default conn doesn't exist
        with pytest.raises(AirflowNotFoundException):
            BaseHook.get_connection(DEFAULT_CONN_ID)

        # verify K8sHook still works
        hook = KubernetesHook()
        assert hook.conn_extras == {}

        # meanwhile, asking for non-default should still fail if it doesn't exist
        hook = KubernetesHook("some_conn")
        with pytest.raises(AirflowNotFoundException, match="The conn_id `some_conn` isn't defined"):
            hook.conn_extras

    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    @patch(f"{HOOK_MODULE}.client.CustomObjectsApi")
    def test_delete_custom_object(
        self, mock_custom_object_api, mock_kube_config_merger, mock_kube_config_loader
    ):
        hook = KubernetesHook()
        hook.delete_custom_object(
            group="group",
            version="version",
            plural="plural",
            name="name",
            namespace="namespace",
            _preload_content="_preload_content",
        )

        mock_custom_object_api.return_value.delete_namespaced_custom_object.assert_called_once_with(
            group="group",
            version="version",
            plural="plural",
            name="name",
            namespace="namespace",
            _preload_content="_preload_content",
        )

    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    @patch(f"{HOOK_MODULE}.KubernetesHook.batch_v1_client")
    def test_get_job_status(self, mock_client, mock_kube_config_merger, mock_kube_config_loader):
        job_expected = mock_client.read_namespaced_job_status.return_value

        hook = KubernetesHook()
        job_actual = hook.get_job_status(job_name=JOB_NAME, namespace=NAMESPACE)

        mock_client.read_namespaced_job_status.assert_called_once_with(
            name=JOB_NAME, namespace=NAMESPACE, pretty=True
        )
        assert job_actual == job_expected

    @pytest.mark.parametrize(
        "conditions, expected_result",
        [
            (None, False),
            ([], False),
            ([mock.MagicMock(type="Complete", status=True)], False),
            ([mock.MagicMock(type="Complete", status=False)], False),
            ([mock.MagicMock(type="Failed", status=False)], False),
            ([mock.MagicMock(type="Failed", status=True, reason="test reason 1")], "test reason 1"),
            (
                [
                    mock.MagicMock(type="Complete", status=False),
                    mock.MagicMock(type="Failed", status=True, reason="test reason 2"),
                ],
                "test reason 2",
            ),
            (
                [
                    mock.MagicMock(type="Complete", status=True),
                    mock.MagicMock(type="Failed", status=True, reason="test reason 3"),
                ],
                "test reason 3",
            ),
        ],
    )
    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    def test_is_job_failed(self, mock_merger, mock_loader, conditions, expected_result):
        mock_job = mock.MagicMock()
        mock_job.status.conditions = conditions

        hook = KubernetesHook()
        actual_result = hook.is_job_failed(mock_job)

        assert actual_result == expected_result

    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    def test_is_job_failed_no_status(self, mock_merger, mock_loader):
        mock_job = mock.MagicMock()
        mock_job.status = None

        hook = KubernetesHook()
        job_failed = hook.is_job_failed(mock_job)

        assert not job_failed

    @pytest.mark.parametrize(
        "condition_type, status, expected_result",
        [
            ("Complete", False, False),
            ("Complete", True, True),
            ("Failed", False, False),
            ("Failed", True, False),
            ("Suspended", False, False),
            ("Suspended", True, False),
            ("Unknown", False, False),
            ("Unknown", True, False),
        ],
    )
    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    def test_is_job_successful(self, mock_merger, mock_loader, condition_type, status, expected_result):
        mock_job = mock.MagicMock()
        mock_job.status.conditions = [mock.MagicMock(type=condition_type, status=status)]

        hook = KubernetesHook()
        actual_result = hook.is_job_successful(mock_job)

        assert actual_result == expected_result

    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    def test_is_job_successful_no_status(self, mock_merger, mock_loader):
        mock_job = mock.MagicMock()
        mock_job.status = None

        hook = KubernetesHook()
        job_successful = hook.is_job_successful(mock_job)

        assert not job_successful

    @pytest.mark.parametrize(
        "condition_type, status, expected_result",
        [
            ("Complete", False, False),
            ("Complete", True, True),
            ("Failed", False, False),
            ("Failed", True, True),
            ("Suspended", False, False),
            ("Suspended", True, False),
            ("Unknown", False, False),
            ("Unknown", True, False),
        ],
    )
    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    def test_is_job_complete(self, mock_merger, mock_loader, condition_type, status, expected_result):
        mock_job = mock.MagicMock()
        mock_job.status.conditions = [mock.MagicMock(type=condition_type, status=status)]

        hook = KubernetesHook()
        actual_result = hook.is_job_complete(mock_job)

        assert actual_result == expected_result

    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    def test_is_job_complete_no_status(self, mock_merger, mock_loader):
        mock_job = mock.MagicMock()
        mock_job.status = None

        hook = KubernetesHook()
        job_complete = hook.is_job_complete(mock_job)

        assert not job_complete

    @patch("kubernetes.config.kube_config.KubeConfigLoader")
    @patch("kubernetes.config.kube_config.KubeConfigMerger")
    @patch(f"{HOOK_MODULE}.KubernetesHook.get_job_status")
    def test_wait_until_job_complete(self, mock_job_status, mock_kube_config_merger, mock_kube_config_loader):
        job_expected = mock.MagicMock(
            status=mock.MagicMock(
                conditions=[
                    mock.MagicMock(type="TestType1"),
                    mock.MagicMock(type="TestType2"),
                    mock.MagicMock(type="Complete", status=True),
                ]
            )
        )
        mock_job_status.side_effect = [
            mock.MagicMock(status=mock.MagicMock(conditions=None)),
            mock.MagicMock(status=mock.MagicMock(conditions=[mock.MagicMock(type="TestType")])),
            mock.MagicMock(
                status=mock.MagicMock(
                    conditions=[
                        mock.MagicMock(type="TestType1"),
                        mock.MagicMock(type="TestType2"),
                    ]
                )
            ),
            mock.MagicMock(
                status=mock.MagicMock(
                    conditions=[
                        mock.MagicMock(type="TestType1"),
                        mock.MagicMock(type="TestType2"),
                        mock.MagicMock(type="Complete", status=False),
                    ]
                )
            ),
            job_expected,
        ]

        hook = KubernetesHook()
        with patch(f"{HOOK_MODULE}.sleep", return_value=None) as mock_sleep:
            job_actual = hook.wait_until_job_complete(
                job_name=JOB_NAME, namespace=NAMESPACE, job_poll_interval=POLL_INTERVAL
            )

        mock_job_status.assert_has_calls([mock.call(job_name=JOB_NAME, namespace=NAMESPACE)] * 5)
        mock_sleep.assert_has_calls([mock.call(POLL_INTERVAL)] * 4)
        assert job_actual == job_expected

    @patch(f"{HOOK_MODULE}.json.dumps")
    @patch(f"{HOOK_MODULE}.KubernetesHook.batch_v1_client")
    def test_create_job_retries_on_500_error(self, mock_client, mock_json_dumps):
        mock_client.create_namespaced_job.side_effect = [
            ApiException(status=500),
            MagicMock(),
        ]

        hook = KubernetesHook()
        hook.create_job(job=mock.MagicMock())

        assert mock_client.create_namespaced_job.call_count == 2

    @patch(f"{HOOK_MODULE}.json.dumps")
    @patch(f"{HOOK_MODULE}.KubernetesHook.batch_v1_client")
    def test_create_job_fails_on_other_exception(self, mock_client, mock_json_dumps):
        mock_client.create_namespaced_job.side_effect = [ApiException(status=404)]

        hook = KubernetesHook()
        with pytest.raises(ApiException):
            hook.create_job(job=mock.MagicMock())

    @patch(f"{HOOK_MODULE}.json.dumps")
    @patch(f"{HOOK_MODULE}.KubernetesHook.batch_v1_client")
    def test_create_job_retries_three_times(self, mock_client, mock_json_dumps):
        mock_client.create_namespaced_job.side_effect = [
            ApiException(status=500),
            ApiException(status=500),
            ApiException(status=500),
            ApiException(status=500),
        ]

        hook = KubernetesHook()
        with pytest.raises(ApiException):
            hook.create_job(job=mock.MagicMock())

        assert mock_client.create_namespaced_job.call_count == 3
| |
| |
class TestKubernetesHookIncorrectConfiguration:
    """Connections that combine mutually exclusive options must be rejected by ``get_conn``."""

    @pytest.mark.parametrize(
        "conn_uri",
        [
            "kubernetes://?kube_config_path=/tmp/&kube_config=[1,2,3]",
            "kubernetes://?kube_config_path=/tmp/&in_cluster=[1,2,3]",
            "kubernetes://?kube_config=/tmp/&in_cluster=[1,2,3]",
        ],
    )
    def test_should_raise_exception_on_invalid_configuration(self, conn_uri):
        hook = KubernetesHook()
        # The default connection is supplied through the environment for this test only.
        with mock.patch.dict("os.environ", AIRFLOW_CONN_KUBERNETES_DEFAULT=conn_uri):
            with pytest.raises(AirflowException, match="Invalid connection configuration"):
                hook.get_conn()
| |
| |
class TestAsyncKubernetesHook:
    """Tests for ``AsyncKubernetesHook``."""

    KUBE_CONFIG_MERGER = "kubernetes_asyncio.config.kube_config.KubeConfigMerger"
    INCLUSTER_CONFIG_LOADER = "kubernetes_asyncio.config.incluster_config.InClusterConfigLoader"
    KUBE_LOADER_CONFIG = "kubernetes_asyncio.config.kube_config.KubeConfigLoader"
    KUBE_API = "kubernetes_asyncio.client.api.core_v1_api.CoreV1Api.{}"
    KUBE_BATCH_API = "kubernetes_asyncio.client.api.batch_v1_api.BatchV1Api.{}"
    KUBE_ASYNC_HOOK = HOOK_MODULE + ".AsyncKubernetesHook.{}"

    @staticmethod
    def mock_await_result(return_value):
        """Return an already-resolved ``Future`` so a mocked call can be awaited."""
        f = Future()
        f.set_result(return_value)
        return f

    @pytest.fixture
    def kube_config_loader(self):
        """Patch ``KubeConfigLoader`` so that awaiting ``load_and_set`` yields ``None``."""
        with mock.patch(self.KUBE_LOADER_CONFIG) as kube_config_loader:
            kube_config_loader.return_value.load_and_set.return_value = self.mock_await_result(None)
            yield kube_config_loader

    @staticmethod
    @pytest.fixture
    def kubernetes_connection():
        """Create the ``CONN_ID`` connection with an inline ``kube_config``; clear connections after."""
        extra = {"kube_config": '{"test": "kube"}'}
        merge_conn(
            Connection(
                conn_type="kubernetes",
                conn_id=CONN_ID,
                extra=json.dumps(extra),
            ),
        )
        yield
        clear_db_connections()

    @pytest.mark.asyncio
    @mock.patch(INCLUSTER_CONFIG_LOADER)
    @mock.patch(KUBE_LOADER_CONFIG)
    @mock.patch(KUBE_CONFIG_MERGER)
    async def test_load_config_with_incluster(self, kube_config_merger, kube_config_loader, incluster_config):
        hook = AsyncKubernetesHook(
            conn_id=None,
            in_cluster=True,
            config_file=None,
            cluster_context=None,
        )
        await hook._load_config()
        incluster_config.assert_called_once()
        assert not kube_config_loader.called
        assert not kube_config_merger.called

    @pytest.mark.asyncio
    @mock.patch(INCLUSTER_CONFIG_LOADER)
    @mock.patch(KUBE_CONFIG_MERGER)
    async def test_load_config_with_config_path(
        self, kube_config_merger, incluster_config, kube_config_loader
    ):
        hook = AsyncKubernetesHook(
            conn_id=None,
            in_cluster=False,
            config_file=ASYNC_CONFIG_PATH,
            cluster_context=None,
        )
        await hook._load_config()
        assert not incluster_config.called
        kube_config_loader.assert_called_once()
        kube_config_merger.assert_called_once()

    @pytest.mark.asyncio
    @mock.patch(INCLUSTER_CONFIG_LOADER)
    @mock.patch(KUBE_CONFIG_MERGER)
    async def test_load_config_with_conn_id(
        self,
        kube_config_merger,
        incluster_config,
        kube_config_loader,
        kubernetes_connection,
    ):
        hook = AsyncKubernetesHook(
            conn_id=CONN_ID,
            in_cluster=False,
            config_file=None,
            cluster_context=None,
        )
        await hook._load_config()
        assert not incluster_config.called
        kube_config_loader.assert_called_once()
        kube_config_merger.assert_called_once()

    @pytest.mark.asyncio
    @mock.patch(INCLUSTER_CONFIG_LOADER)
    @mock.patch(KUBE_CONFIG_MERGER)
    async def test_load_config_with_default_client(
        self,
        kube_config_merger,
        incluster_config,
        kube_config_loader,
    ):
        hook = AsyncKubernetesHook(
            conn_id=None,
            in_cluster=False,
            config_file=None,
            cluster_context=None,
        )
        kube_client = await hook._load_config()

        assert not incluster_config.called
        kube_config_loader.assert_called_once()
        kube_config_merger.assert_called_once()
        # It should return None in case when default client is used
        assert kube_client is None

    @pytest.mark.asyncio
    async def test_load_config_with_several_params(
        self,
        kubernetes_connection,
    ):
        """Combining in_cluster, config_file and a connection kube_config must be rejected.

        The connection fixture is required: without it, looking up ``CONN_ID`` raises
        ``AirflowNotFoundException`` (a subclass of ``AirflowException``), and the test would
        pass without ever reaching the mutual-exclusivity check. ``match`` pins the real error.
        """
        hook = AsyncKubernetesHook(
            conn_id=CONN_ID,
            in_cluster=True,
            config_file=ASYNC_CONFIG_PATH,
            cluster_context=None,
        )
        with pytest.raises(AirflowException, match="Invalid connection configuration"):
            await hook._load_config()

    @pytest.mark.asyncio
    @mock.patch(KUBE_API.format("read_namespaced_pod"))
    async def test_get_pod(self, lib_method, kube_config_loader):
        lib_method.return_value = self.mock_await_result(None)

        hook = AsyncKubernetesHook(
            conn_id=None,
            in_cluster=False,
            config_file=None,
            cluster_context=None,
        )
        await hook.get_pod(
            name=POD_NAME,
            namespace=NAMESPACE,
        )

        lib_method.assert_called_once()
        lib_method.assert_called_with(
            name=POD_NAME,
            namespace=NAMESPACE,
        )

    @pytest.mark.asyncio
    @mock.patch(KUBE_API.format("delete_namespaced_pod"))
    async def test_delete_pod(self, lib_method, kube_config_loader):
        lib_method.return_value = self.mock_await_result(None)

        hook = AsyncKubernetesHook(
            conn_id=None,
            in_cluster=False,
            config_file=None,
            cluster_context=None,
        )
        await hook.delete_pod(
            name=POD_NAME,
            namespace=NAMESPACE,
        )

        lib_method.assert_called_once()

    @pytest.mark.asyncio
    @mock.patch(KUBE_API.format("read_namespaced_pod_log"))
    async def test_read_logs(self, lib_method, kube_config_loader, caplog):
        lib_method.return_value = self.mock_await_result("2023-01-11 Some string logs...")

        hook = AsyncKubernetesHook(
            conn_id=None,
            in_cluster=False,
            config_file=None,
            cluster_context=None,
        )
        await hook.read_logs(
            name=POD_NAME,
            namespace=NAMESPACE,
        )

        lib_method.assert_called_once()
        lib_method.assert_called_with(
            name=POD_NAME,
            namespace=NAMESPACE,
            follow=False,
            timestamps=True,
        )
        assert "Container logs from 2023-01-11 Some string logs..." in caplog.text

    @pytest.mark.asyncio
    @mock.patch(KUBE_BATCH_API.format("read_namespaced_job_status"))
    async def test_get_job_status(self, lib_method, kube_config_loader):
        lib_method.return_value = self.mock_await_result(None)

        hook = AsyncKubernetesHook(
            conn_id=None,
            in_cluster=False,
            config_file=None,
            cluster_context=None,
        )
        await hook.get_job_status(
            name=JOB_NAME,
            namespace=NAMESPACE,
        )

        lib_method.assert_called_once()

    @pytest.mark.asyncio
    @mock.patch(HOOK_MODULE + ".asyncio.sleep")
    @mock.patch(KUBE_ASYNC_HOOK.format("is_job_complete"))
    @mock.patch(KUBE_ASYNC_HOOK.format("get_job_status"))
    async def test_wait_until_job_complete(
        self, mock_get_job_status, mock_is_job_complete, mock_sleep, kube_config_loader
    ):
        mock_job_0, mock_job_1 = mock.MagicMock(), mock.MagicMock()
        # ``patch`` replaces the async ``get_job_status`` with an AsyncMock, so a plain list
        # side_effect makes successive awaits return these jobs in order.
        mock_get_job_status.side_effect = [mock_job_0, mock_job_1]
        mock_is_job_complete.side_effect = [False, True]

        hook = AsyncKubernetesHook(
            conn_id=None,
            in_cluster=False,
            config_file=None,
            cluster_context=None,
        )

        job_actual = await hook.wait_until_job_complete(
            name=JOB_NAME,
            namespace=NAMESPACE,
            poll_interval=10,
        )

        mock_get_job_status.assert_has_awaits(
            [
                mock.call(name=JOB_NAME, namespace=NAMESPACE),
                mock.call(name=JOB_NAME, namespace=NAMESPACE),
            ]
        )
        mock_is_job_complete.assert_has_calls([mock.call(job=mock_job_0), mock.call(job=mock_job_1)])
        mock_sleep.assert_awaited_once_with(10)
        assert job_actual == mock_job_1