blob: 16f49a769d4c75e22a51bc231c3529d91192ffcc [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from unittest.mock import AsyncMock, Mock, call, patch
import pytest
from botocore.exceptions import ClientError
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.triggers.eks import (
EksCreateClusterTrigger,
EksDeleteClusterTrigger,
)
from airflow.triggers.base import TriggerEvent
EXCEPTION_MOCK = AirflowException("MOCK ERROR")
CLUSTER_NAME = "test_cluster"
WAITER_DELAY = 1
WAITER_MAX_ATTEMPTS = 10
AWS_CONN_ID = "test_conn_id"
REGION_NAME = "test-region"
FARGATE_PROFILES = ["p1", "p2"]
class TestEksTrigger:
def setup_method(self):
self.async_conn_patcher = patch("airflow.providers.amazon.aws.hooks.eks.EksHook.async_conn")
self.mock_async_conn = self.async_conn_patcher.start()
self.mock_client = AsyncMock()
self.mock_async_conn.__aenter__.return_value = self.mock_client
self.async_wait_patcher = patch(
"airflow.providers.amazon.aws.triggers.eks.async_wait", return_value=True
)
self.mock_async_wait = self.async_wait_patcher.start()
def teardown_method(self):
self.async_conn_patcher.stop()
self.async_wait_patcher.stop()
class TestEksCreateClusterTrigger(TestEksTrigger):
def setup_method(self):
super().setup_method()
self.trigger = EksCreateClusterTrigger(
cluster_name=CLUSTER_NAME,
waiter_delay=WAITER_DELAY,
waiter_max_attempts=WAITER_MAX_ATTEMPTS,
aws_conn_id=AWS_CONN_ID,
region_name=REGION_NAME,
)
self.trigger.log.error = Mock()
@pytest.mark.asyncio
async def test_when_cluster_is_created_run_should_return_a_success_event(self):
generator = self.trigger.run()
response = await generator.asend(None)
assert response == TriggerEvent({"status": "success"})
@pytest.mark.asyncio
async def test_when_run_raises_exception_it_should_return_a_failure_event(self):
self.mock_async_wait.side_effect = EXCEPTION_MOCK
generator = self.trigger.run()
response = await generator.asend(None)
assert response == TriggerEvent({"status": "failed"})
self.trigger.log.error.assert_called_once_with("Error creating cluster: %s", EXCEPTION_MOCK)
@pytest.mark.asyncio
async def test_run_parameterizes_async_wait_correctly(self):
self.mock_client.get_waiter = Mock(return_value="waiter")
generator = self.trigger.run()
await generator.asend(None)
self.mock_client.get_waiter.assert_called_once_with("cluster_active")
self.mock_async_wait.assert_called_once_with(
"waiter",
WAITER_DELAY,
WAITER_MAX_ATTEMPTS,
{"name": CLUSTER_NAME},
"Error checking Eks cluster",
"Eks cluster status is",
["cluster.status"],
)
class TestEksDeleteClusterTriggerRun(TestEksTrigger):
def setup_method(self):
super().setup_method()
self.delete_any_nodegroups_patcher = patch.object(EksDeleteClusterTrigger, "delete_any_nodegroups")
self.mock_delete_any_nodegroups = self.delete_any_nodegroups_patcher.start()
self.delete_any_fargate_profiles_patcher = patch.object(
EksDeleteClusterTrigger, "delete_any_fargate_profiles"
)
self.mock_delete_any_fargate_profiles = self.delete_any_fargate_profiles_patcher.start()
self.trigger = EksDeleteClusterTrigger(
cluster_name=CLUSTER_NAME,
waiter_delay=WAITER_DELAY,
waiter_max_attempts=WAITER_MAX_ATTEMPTS,
aws_conn_id=AWS_CONN_ID,
region_name=REGION_NAME,
force_delete_compute=False,
)
self.trigger.log.info = Mock()
def teardown_method(self):
super().teardown_method()
self.delete_any_nodegroups_patcher.stop()
self.delete_any_fargate_profiles_patcher.stop()
@pytest.mark.asyncio
async def test_run_deletes_nodegroups_and_fargate_profiles(self):
self.trigger.force_delete_compute = True
generator = self.trigger.run()
response = await generator.asend(None)
self.mock_delete_any_nodegroups.assert_called_once_with(client=self.mock_client)
self.mock_delete_any_fargate_profiles.assert_called_once_with(client=self.mock_client)
assert response == TriggerEvent({"status": "deleted"})
@pytest.mark.asyncio
async def test_when_resource_is_not_found_it_should_return_status_deleted(self):
delete_cluster_mock = AsyncMock(
side_effect=ClientError({"Error": {"Code": "ResourceNotFoundException"}}, "delete_eks_cluster")
)
self.mock_client.delete_cluster = delete_cluster_mock
generator = self.trigger.run()
response = await generator.asend(None)
delete_cluster_mock.assert_called_once_with(name=CLUSTER_NAME)
assert response == TriggerEvent({"status": "deleted"})
@pytest.mark.asyncio
async def test_run_raises_client_error(self):
response = {"Error": {"Code": "OtherException"}}
operation_name = "delete_eks_cluster"
delete_cluster_mock = AsyncMock(side_effect=ClientError(response, "delete_eks_cluster"))
self.mock_client.delete_cluster = delete_cluster_mock
generator = self.trigger.run()
with pytest.raises(ClientError) as exception:
await generator.asend(None)
delete_cluster_mock.assert_called_once_with(name=CLUSTER_NAME)
assert exception._excinfo[1].response == response
assert exception._excinfo[1].operation_name == operation_name
@pytest.mark.asyncio
async def test_run_parameterizes_async_wait_correctly(self):
self.mock_client.get_waiter = Mock(return_value="waiter")
generator = self.trigger.run()
await generator.asend(None)
self.mock_delete_any_fargate_profiles.assert_not_called()
self.mock_delete_any_nodegroups.assert_not_called()
self.mock_client.get_waiter.assert_called_once_with("cluster_deleted")
self.mock_async_wait.assert_called_once_with(
waiter="waiter",
waiter_delay=WAITER_DELAY,
waiter_max_attempts=WAITER_MAX_ATTEMPTS,
args={"name": CLUSTER_NAME},
failure_message="Error deleting cluster",
status_message="Status of cluster is",
status_args=["cluster.status"],
)
class TestEksDeleteClusterTriggerDeleteNodegroupsAndFargateProfiles(TestEksTrigger):
def setup_method(self):
super().setup_method()
self.get_waiter_patcher = patch(
"airflow.providers.amazon.aws.hooks.eks.EksHook.get_waiter", return_value="waiter"
)
self.mock_waiter = self.get_waiter_patcher.start()
self.trigger = EksDeleteClusterTrigger(
cluster_name=CLUSTER_NAME,
waiter_delay=WAITER_DELAY,
waiter_max_attempts=WAITER_MAX_ATTEMPTS,
aws_conn_id=AWS_CONN_ID,
region_name=REGION_NAME,
force_delete_compute=False,
)
self.trigger.log.info = Mock()
def teardown_method(self):
super().teardown_method()
self.get_waiter_patcher.stop()
@pytest.mark.asyncio
async def test_delete_nodegroups(self):
mock_list_node_groups = AsyncMock(return_value={"nodegroups": ["g1", "g2"]})
mock_delete_nodegroup = AsyncMock()
mock_client = AsyncMock(list_nodegroups=mock_list_node_groups, delete_nodegroup=mock_delete_nodegroup)
await self.trigger.delete_any_nodegroups(mock_client)
mock_list_node_groups.assert_called_once_with(clusterName=CLUSTER_NAME)
self.trigger.log.info.assert_has_calls([call("Deleting nodegroups"), call("All nodegroups deleted")])
self.mock_waiter.assert_called_once_with(
"all_nodegroups_deleted", deferrable=True, client=mock_client
)
mock_delete_nodegroup.assert_has_calls(
[
call(clusterName=CLUSTER_NAME, nodegroupName="g1"),
call(clusterName=CLUSTER_NAME, nodegroupName="g2"),
]
)
self.mock_async_wait.assert_called_once_with(
waiter="waiter",
waiter_delay=WAITER_DELAY,
waiter_max_attempts=WAITER_MAX_ATTEMPTS,
args={"clusterName": CLUSTER_NAME},
failure_message="Error deleting nodegroup for cluster test_cluster",
status_message="Deleting nodegroups associated with the cluster",
status_args=["nodegroups"],
)
@pytest.mark.asyncio
async def test_when_there_are_no_nodegroups_it_should_only_log_message(self):
mock_list_node_groups = AsyncMock(return_value={"nodegroups": []})
mock_delete_nodegroup = AsyncMock()
mock_client = AsyncMock(list_nodegroups=mock_list_node_groups, delete_nodegroup=mock_delete_nodegroup)
await self.trigger.delete_any_nodegroups(mock_client)
mock_list_node_groups.assert_called_once_with(clusterName=CLUSTER_NAME)
self.mock_async_wait.assert_not_called()
mock_delete_nodegroup.assert_not_called()
self.trigger.log.info.assert_called_once_with(
"No nodegroups associated with cluster %s", CLUSTER_NAME
)
@pytest.mark.asyncio
async def test_delete_any_fargate_profiles(self):
mock_list_fargate_profiles = AsyncMock(return_value={"fargateProfileNames": FARGATE_PROFILES})
mock_delete_fargate_profile = AsyncMock()
mock_client = AsyncMock(
list_fargate_profiles=mock_list_fargate_profiles,
delete_fargate_profile=mock_delete_fargate_profile,
get_waiter=self.mock_waiter,
)
await self.trigger.delete_any_fargate_profiles(mock_client)
mock_list_fargate_profiles.assert_called_once_with(clusterName=CLUSTER_NAME)
self.trigger.log.info.assert_has_calls(
[
call("Waiting for Fargate profiles to delete. This will take some time."),
call("All Fargate profiles deleted"),
]
)
self.mock_waiter.assert_has_calls([call("fargate_profile_deleted"), call("fargate_profile_deleted")])
mock_delete_fargate_profile.assert_has_calls(
[
call(clusterName=CLUSTER_NAME, fargateProfileName="p1"),
call(clusterName=CLUSTER_NAME, fargateProfileName="p2"),
]
)
self.mock_async_wait.assert_has_calls(
[
call(
waiter="waiter",
waiter_delay=WAITER_DELAY,
waiter_max_attempts=WAITER_MAX_ATTEMPTS,
args={"clusterName": CLUSTER_NAME, "fargateProfileName": profile},
failure_message="Error deleting fargate profile for cluster test_cluster",
status_message="Status of fargate profile is",
status_args=["fargateProfile.status"],
)
for profile in FARGATE_PROFILES
]
)
@pytest.mark.asyncio
async def test_when_there_are_no_fargate_profiles_it_should_only_log_message(self):
mock_list_fargate_profiles = AsyncMock(return_value={"fargateProfileNames": []})
mock_delete_fargate_profile = AsyncMock()
mock_client = AsyncMock(
list_fargate_profiles=mock_list_fargate_profiles,
delete_fargate_profile=mock_delete_fargate_profile,
)
await self.trigger.delete_any_fargate_profiles(mock_client)
mock_list_fargate_profiles.assert_called_once_with(clusterName=CLUSTER_NAME)
self.mock_async_wait.assert_not_called()
mock_delete_fargate_profile.assert_not_called()
self.trigger.log.info.assert_called_once_with(
"No Fargate profiles associated with cluster %s", CLUSTER_NAME
)