| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """ |
| AWS ECS Executor. |
| |
Each Airflow task is delegated to an Amazon ECS task.
| """ |
| |
| from __future__ import annotations |
| |
| import time |
| from collections import defaultdict, deque |
| from copy import deepcopy |
| from typing import TYPE_CHECKING, Sequence |
| |
| from botocore.exceptions import ClientError, NoCredentialsError |
| |
| from airflow.configuration import conf |
| from airflow.exceptions import AirflowException |
| from airflow.executors.base_executor import BaseExecutor |
| from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoDescribeTasksSchema, BotoRunTaskSchema |
| from airflow.providers.amazon.aws.executors.ecs.utils import ( |
| CONFIG_DEFAULTS, |
| CONFIG_GROUP_NAME, |
| AllEcsConfigKeys, |
| EcsExecutorException, |
| EcsQueuedTask, |
| EcsTaskCollection, |
| ) |
| from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import ( |
| calculate_next_attempt_delay, |
| exponential_backoff_retry, |
| ) |
| from airflow.providers.amazon.aws.hooks.ecs import EcsHook |
| from airflow.stats import Stats |
| from airflow.utils import timezone |
| from airflow.utils.helpers import merge_dicts |
| from airflow.utils.state import State |
| |
| if TYPE_CHECKING: |
| from airflow.models.taskinstance import TaskInstance, TaskInstanceKey |
| from airflow.providers.amazon.aws.executors.ecs.utils import ( |
| CommandType, |
| ExecutorConfigType, |
| ) |
| |
| INVALID_CREDENTIALS_EXCEPTIONS = [ |
| "ExpiredTokenException", |
| "InvalidClientTokenId", |
| "UnrecognizedClientException", |
| ] |
| |
| |
| class AwsEcsExecutor(BaseExecutor): |
| """ |
    Executes the provided Airflow command as a task on an Amazon ECS cluster.
| |
    The Airflow Scheduler creates a shell command and passes it to the executor. This ECS Executor
    runs that Airflow command on a remote Amazon ECS cluster, using a task definition configured to
    launch the same containers as the Scheduler. It then periodically checks in with the launched
    tasks (via task ARNs) to determine their status.
| |
    This allows individual tasks to specify CPU, memory, GPU, environment variables, etc. When
    initializing a task, there is an option to pass an "executor config": a dictionary that is
    merged into the executor's ``run_task`` kwargs, with container-level settings following the
    ``ContainerOverride`` definition in the AWS documentation (see link below).
| |
    Prerequisite: proper configuration of the Boto3 library
    .. seealso:: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html for
    authentication and access-key management. Credentials can be supplied via environment variables,
    the shared AWS config and credentials files, or IAM roles.
| |
| .. seealso:: https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_ContainerOverride.html for an |
| Airflow TaskInstance's executor_config. |
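
    For example, a task might pass an executor config like the following (an illustrative sketch;
    the container name and values are assumptions and must match your own task definition and this
    executor's configured ``container_name``):

    .. code-block:: python

        executor_config = {
            "overrides": {
                "containerOverrides": [
                    {
                        "name": "my-airflow-container",
                        "memoryReservation": 2048,
                    }
                ]
            }
        }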
| """ |
| |
    # Maximum number of attempts to run an ECS task.
| MAX_RUN_TASK_ATTEMPTS = conf.get( |
| CONFIG_GROUP_NAME, |
| AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS, |
| fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS], |
| ) |
| |
    # AWS limits the number of ARNs that can be passed to describe_tasks in a single call
    # (100 at the time of writing), so requests are batched to stay under that limit.
| DESCRIBE_TASKS_BATCH_SIZE = 99 |
| |
| def __init__(self, *args, **kwargs): |
| super().__init__(*args, **kwargs) |
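        # active_workers tracks tasks already submitted to ECS, keyed by task ARN;
        # pending_tasks queues EcsQueuedTask entries that have not yet been run successfully.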
| self.active_workers: EcsTaskCollection = EcsTaskCollection() |
| self.pending_tasks: deque = deque() |
| |
| self.cluster = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER) |
| self.container_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME) |
| self.attempts_since_last_successful_connection = 0 |
| |
| self.load_ecs_connection(check_connection=False) |
| self.IS_BOTO_CONNECTION_HEALTHY = False |
| |
| self.run_task_kwargs = self._load_run_kwargs() |
| |
| def start(self): |
| """Call this when the Executor is run for the first time by the scheduler.""" |
| check_health = conf.getboolean( |
| CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False |
| ) |
| |
| if not check_health: |
| return |
| |
| self.log.info("Starting ECS Executor and determining health...") |
| try: |
| self.check_health() |
| except AirflowException: |
| self.log.error("Stopping the Airflow Scheduler from starting until the issue is resolved.") |
| raise |
| |
| def check_health(self): |
| """ |
| Make a test API call to check the health of the ECS Executor. |
| |
        Deliberately use an invalid task ID. Some potential outcomes, in order:
| 1. `AccessDeniedException` is raised if there are insufficient permissions. |
| 2. `ClusterNotFoundException` is raised if permissions exist but the cluster does not. |
| 3. The API responds with a failure message if the cluster is found and there |
| are permissions, but the cluster itself has issues. |
| 4. `InvalidParameterException` is raised if the permissions and cluster exist but the task does not. |
| |
| The last one is considered a success state for the purposes of this check. |
| """ |
| success_status = "succeeded." |
| status = success_status |
| |
| try: |
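            # A 32-character ID matches the format of a real ECS task ID (an assumption this
            # check relies on), so the call is well-formed but should never resolve to a task.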
| invalid_task_id = "a" * 32 |
| self.ecs.stop_task(cluster=self.cluster, task=invalid_task_id) |
| |
| # If it got this far, something is wrong. stop_task() called with an |
| # invalid taskID should have thrown a ClientError. All known reasons are |
| # covered in the ``except`` block below, and this should never be reached. |
| status = "failed for an unknown reason. " |
| except ClientError as ex: |
| error_code = ex.response["Error"]["Code"] |
| error_message = ex.response["Error"]["Message"] |
| |
| if ("InvalidParameterException" in error_code) and ("task was not found" in error_message): |
| # This failure is expected, and means we're healthy |
| pass |
| else: |
| # Catch all for unexpected failures |
| status = f"failed because: {error_message}. " |
| except Exception as e: |
| # Any non-ClientError exceptions. This can include Botocore exceptions for example |
| status = f"failed because: {e}. " |
| finally: |
| msg_prefix = "ECS Executor health check has %s" |
| if status == success_status: |
| self.IS_BOTO_CONNECTION_HEALTHY = True |
| self.log.info(msg_prefix, status) |
| else: |
| msg_error_suffix = ( |
| "The ECS executor will not be able to run Airflow tasks until the issue is addressed." |
| ) |
| raise AirflowException(msg_prefix % status + msg_error_suffix) |
| |
    def load_ecs_connection(self, check_connection: bool = True):
        """(Re)load the Boto3 ECS client, optionally verifying it with a health check."""
        self.log.info("Loading Connection information")
| aws_conn_id = conf.get( |
| CONFIG_GROUP_NAME, |
| AllEcsConfigKeys.AWS_CONN_ID, |
| fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID], |
| ) |
| region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME, fallback=None) |
| self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn |
| self.attempts_since_last_successful_connection += 1 |
| self.last_connection_reload = timezone.utcnow() |
| |
| if check_connection: |
| self.check_health() |
| self.attempts_since_last_successful_connection = 0 |
| |
    def sync(self):
        """Sync task state and attempt task runs; called periodically by the scheduler heartbeat."""
| if not self.IS_BOTO_CONNECTION_HEALTHY: |
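            # The previous connection attempt failed; retry loading the connection, with an
            # exponentially increasing delay between attempts.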
| exponential_backoff_retry( |
| self.last_connection_reload, |
| self.attempts_since_last_successful_connection, |
| self.load_ecs_connection, |
| ) |
| if not self.IS_BOTO_CONNECTION_HEALTHY: |
| return |
| try: |
| self.sync_running_tasks() |
| self.attempt_task_runs() |
        except NoCredentialsError:
            # NoCredentialsError has no ``response`` attribute, so it must be handled
            # separately from ClientError.
            self.IS_BOTO_CONNECTION_HEALTHY = False
            self.log.warning("AWS credentials are missing. Retrying connection")
        except ClientError as error:
            error_code = error.response["Error"]["Code"]
            if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
                self.IS_BOTO_CONNECTION_HEALTHY = False
                self.log.warning(
                    "AWS credentials are either missing or expired: %s.\nRetrying connection", error
                )
| |
| except Exception: |
| # We catch any and all exceptions because otherwise they would bubble |
| # up and kill the scheduler process |
| self.log.exception("Failed to sync %s", self.__class__.__name__) |
| |
| def sync_running_tasks(self): |
| """Check and update state on all running tasks.""" |
| all_task_arns = self.active_workers.get_all_arns() |
| if not all_task_arns: |
| self.log.debug("No active Airflow tasks, skipping sync.") |
| return |
| |
| describe_tasks_response = self.__describe_tasks(all_task_arns) |
| |
| self.log.debug("Active Workers: %s", describe_tasks_response) |
| |
| if describe_tasks_response["failures"]: |
| for failure in describe_tasks_response["failures"]: |
| self.__handle_failed_task(failure["arn"], failure["reason"]) |
| |
| updated_tasks = describe_tasks_response["tasks"] |
| for task in updated_tasks: |
| self.__update_running_task(task) |
| |
| def __update_running_task(self, task): |
| self.active_workers.update_task(task) |
| # Get state of current task. |
| task_state = task.get_task_state() |
| task_key = self.active_workers.arn_to_key[task.task_arn] |
| |
| # Mark finished tasks as either a success/failure. |
        if task_state in (State.FAILED, State.REMOVED):
| self.__log_container_failures(task_arn=task.task_arn) |
| self.__handle_failed_task(task.task_arn, task.stopped_reason) |
| elif task_state == State.SUCCESS: |
| self.log.debug( |
| "Airflow task %s marked as %s after running on ECS Task (arn) %s", |
| task_key, |
| task_state, |
| task.task_arn, |
| ) |
| self.success(task_key) |
| self.active_workers.pop_by_key(task_key) |
| |
| def __describe_tasks(self, task_arns): |
| all_task_descriptions = {"tasks": [], "failures": []} |
| for i in range(0, len(task_arns), self.DESCRIBE_TASKS_BATCH_SIZE): |
| batched_task_arns = task_arns[i : i + self.DESCRIBE_TASKS_BATCH_SIZE] |
| if not batched_task_arns: |
| continue |
| boto_describe_tasks = self.ecs.describe_tasks(tasks=batched_task_arns, cluster=self.cluster) |
| describe_tasks_response = BotoDescribeTasksSchema().load(boto_describe_tasks) |
| |
| all_task_descriptions["tasks"].extend(describe_tasks_response["tasks"]) |
| all_task_descriptions["failures"].extend(describe_tasks_response["failures"]) |
| return all_task_descriptions |
| |
    def __log_container_failures(self, task_arn: str):
        """Log the reasons reported by the task's containers, if the task failed because of them."""
        containers = self.active_workers.task_by_arn(task_arn).containers
        has_exit_codes = all("exit_code" in x for x in containers)
        if not has_exit_codes:
            return
| reasons = [ |
| f'{container["container_arn"]} - {container["reason"]}' |
| for container in containers |
| if "reason" in container |
| ] |
| if reasons: |
| self.log.warning( |
| "The ECS task failed due to the following containers failing:\n%s", "\n".join(reasons) |
| ) |
| |
| def __handle_failed_task(self, task_arn: str, reason: str): |
| """ |
| If an API failure occurs, the task is rescheduled. |
| |
| This function will determine whether the task has been attempted the appropriate number |
| of times, and determine whether the task should be marked failed or not. The task will |
| be removed active_workers, and marked as FAILED, or set into pending_tasks depending on |
| how many times it has been retried. |
| """ |
| task_key = self.active_workers.arn_to_key[task_arn] |
| task_info = self.active_workers.info_by_key(task_key) |
| task_cmd = task_info.cmd |
| queue = task_info.queue |
| exec_info = task_info.config |
| failure_count = self.active_workers.failure_count_by_key(task_key) |
| if int(failure_count) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS): |
| self.log.warning( |
| "Airflow task %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.", |
| task_key, |
| reason, |
| failure_count, |
| self.__class__.MAX_RUN_TASK_ATTEMPTS, |
| task_arn, |
| ) |
| self.pending_tasks.append( |
| EcsQueuedTask( |
| task_key, |
| task_cmd, |
| queue, |
| exec_info, |
| failure_count + 1, |
| timezone.utcnow() + calculate_next_attempt_delay(failure_count), |
| ) |
| ) |
| else: |
| self.log.error( |
| "Airflow task %s has failed a maximum of %s times. Marking as failed", |
| task_key, |
| failure_count, |
| ) |
| self.fail(task_key) |
| self.active_workers.pop_by_key(task_key) |
| |
| def attempt_task_runs(self): |
| """ |
| Take tasks from the pending_tasks queue, and attempts to find an instance to run it on. |
| |
| If the launch type is EC2, this will attempt to place tasks on empty EC2 instances. If |
| there are no EC2 instances available, no task is placed and this function will be |
| called again in the next heart-beat. |
| |
| If the launch type is FARGATE, this will run the tasks on new AWS Fargate instances. |
| """ |
| queue_len = len(self.pending_tasks) |
| failure_reasons = defaultdict(int) |
| for _ in range(queue_len): |
| ecs_task = self.pending_tasks.popleft() |
| task_key = ecs_task.key |
| cmd = ecs_task.command |
| queue = ecs_task.queue |
| exec_config = ecs_task.executor_config |
| attempt_number = ecs_task.attempt_number |
| _failure_reasons = [] |
| if timezone.utcnow() < ecs_task.next_attempt_time: |
| self.pending_tasks.append(ecs_task) |
| continue |
| try: |
| run_task_response = self._run_task(task_key, cmd, queue, exec_config) |
| except NoCredentialsError: |
| self.pending_tasks.append(ecs_task) |
| raise |
| except ClientError as e: |
| error_code = e.response["Error"]["Code"] |
| if error_code in INVALID_CREDENTIALS_EXCEPTIONS: |
| self.pending_tasks.append(ecs_task) |
| raise |
| _failure_reasons.append(str(e)) |
| except Exception as e: |
| # Failed to even get a response back from the Boto3 API or something else went |
| # wrong. For any possible failure we want to add the exception reasons to the |
| # failure list so that it is logged to the user and most importantly the task is |
| # added back to the pending list to be retried later. |
| _failure_reasons.append(str(e)) |
| else: |
| # We got a response back, check if there were failures. If so, add them to the |
| # failures list so that it is logged to the user and most importantly the task |
| # is added back to the pending list to be retried later. |
| if run_task_response["failures"]: |
| _failure_reasons.extend([f["reason"] for f in run_task_response["failures"]]) |
| |
| if _failure_reasons: |
| for reason in _failure_reasons: |
| failure_reasons[reason] += 1 |
| # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS |
| if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS): |
| ecs_task.attempt_number += 1 |
| ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay( |
| attempt_number |
| ) |
| self.pending_tasks.append(ecs_task) |
| else: |
| self.log.error( |
| "ECS task %s has failed a maximum of %s times. Marking as failed", |
| task_key, |
| attempt_number, |
| ) |
| self.fail(task_key) |
| elif not run_task_response["tasks"]: |
| self.log.error("ECS RunTask Response: %s", run_task_response) |
| raise EcsExecutorException( |
| "No failures and no ECS tasks provided in response. This should never happen." |
| ) |
| else: |
| task = run_task_response["tasks"][0] |
| self.active_workers.add_task(task, task_key, queue, cmd, exec_config, attempt_number) |
| try: |
| self.running_state(task_key, task.task_arn) |
| except AttributeError: |
| # running_state is newly added, and only needed to support task adoption (an optional |
| # executor feature). |
| pass |
| if failure_reasons: |
| self.log.error( |
| "Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.", |
| dict(failure_reasons), |
| ) |
| |
| def _run_task( |
| self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType |
| ): |
| """ |
| Run a queued-up Airflow task. |
| |
        Not to be confused with execute_async(), which inserts tasks into the queue.
| The command and executor config will be placed in the container-override |
| section of the JSON request before calling Boto3's "run_task" function. |
| """ |
| run_task_kwargs = self._run_task_kwargs(task_id, cmd, queue, exec_config) |
| boto_run_task = self.ecs.run_task(**run_task_kwargs) |
| run_task_response = BotoRunTaskSchema().load(boto_run_task) |
| return run_task_response |
| |
| def _run_task_kwargs( |
| self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType |
| ) -> dict: |
| """ |
| Update the Airflow command by modifying container overrides for task-specific kwargs. |
| |
| One last chance to modify Boto3's "run_task" kwarg params before it gets passed into the Boto3 client. |
| """ |
| run_task_kwargs = deepcopy(self.run_task_kwargs) |
| run_task_kwargs = merge_dicts(run_task_kwargs, exec_config) |
| container_override = self.get_container(run_task_kwargs["overrides"]["containerOverrides"]) |
| container_override["command"] = cmd |
| |
| # Inject the env variable to configure logging for containerized execution environment |
| if "environment" not in container_override: |
| container_override["environment"] = [] |
| container_override["environment"].append({"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"}) |
| |
| return run_task_kwargs |
| |
| def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None): |
| """Save the task to be executed in the next sync by inserting the commands into a queue.""" |
| if executor_config and ("name" in executor_config or "command" in executor_config): |
| raise ValueError('Executor Config should never override "name" or "command"') |
| self.pending_tasks.append( |
| EcsQueuedTask(key, command, queue, executor_config or {}, 1, timezone.utcnow()) |
| ) |
| |
| def end(self, heartbeat_interval=10): |
| """Wait for all currently running tasks to end, and don't launch any tasks.""" |
| try: |
| while True: |
| self.sync() |
| if not self.active_workers: |
| break |
| time.sleep(heartbeat_interval) |
| except Exception: |
| # We catch any and all exceptions because otherwise they would bubble |
| # up and kill the scheduler process. |
| self.log.exception("Failed to end %s", self.__class__.__name__) |
| |
| def terminate(self): |
| """Kill all ECS processes by calling Boto3's StopTask API.""" |
| try: |
| for arn in self.active_workers.get_all_arns(): |
| self.ecs.stop_task( |
| cluster=self.cluster, task=arn, reason="Airflow Executor received a SIGTERM" |
| ) |
| self.end() |
| except Exception: |
| # We catch any and all exceptions because otherwise they would bubble |
| # up and kill the scheduler process. |
| self.log.exception("Failed to terminate %s", self.__class__.__name__) |
| |
| def _load_run_kwargs(self) -> dict: |
| from airflow.providers.amazon.aws.executors.ecs.ecs_executor_config import build_task_kwargs |
| |
| ecs_executor_run_task_kwargs = build_task_kwargs() |
| |
| try: |
| self.get_container(ecs_executor_run_task_kwargs["overrides"]["containerOverrides"])["command"] |
| except KeyError: |
| raise KeyError( |
| "Rendered JSON template does not contain key " |
| '"overrides[containerOverrides][containers][x][command]"' |
| ) |
| return ecs_executor_run_task_kwargs |
| |
| def get_container(self, container_list): |
| """Search task list for core Airflow container.""" |
| for container in container_list: |
| try: |
| if container["name"] == self.container_name: |
| return container |
| except KeyError: |
| raise EcsExecutorException( |
| 'container "name" must be provided in "containerOverrides" configuration' |
| ) |
| raise KeyError(f"No such container found by container name: {self.container_name}") |
| |
| def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: |
| """ |
| Adopt task instances which have an external_executor_id (the ECS task ARN). |
| |
| Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling. |
| """ |
| with Stats.timer("ecs_executor.adopt_task_instances.duration"): |
| adopted_tis: list[TaskInstance] = [] |
| |
| if task_arns := [ti.external_executor_id for ti in tis if ti.external_executor_id]: |
| task_descriptions = self.__describe_tasks(task_arns).get("tasks", []) |
| |
| for task in task_descriptions: |
| ti = next(ti for ti in tis if ti.external_executor_id == task.task_arn) |
| self.active_workers.add_task( |
| task, |
| ti.key, |
| ti.queue, |
| ti.command_as_list(), |
| ti.executor_config, |
| ti.try_number, |
| ) |
| adopted_tis.append(ti) |
| |
| if adopted_tis: |
| tasks = [f"{task} in state {task.state}" for task in adopted_tis] |
| task_instance_str = "\n\t".join(tasks) |
| self.log.info( |
| "Adopted the following %d tasks from a dead executor:\n\t%s", |
| len(adopted_tis), |
| task_instance_str, |
| ) |
| |
| not_adopted_tis = [ti for ti in tis if ti not in adopted_tis] |
| return not_adopted_tis |