| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """DAG Cycle tester""" |
| from collections import defaultdict, deque |
| |
| from airflow.exceptions import AirflowDagCycleException |
| |
# Visit states used by ``test_cycle``'s depth-first search. CYCLE_NEW must stay
# 0: ``test_cycle`` keeps the states in a ``defaultdict(int)``, so a task that
# has never been seen reads as CYCLE_NEW without explicit initialisation.
CYCLE_NEW, CYCLE_IN_PROGRESS, CYCLE_DONE = range(3)
| |
| |
def test_cycle(dag):
    """
    Check to see if there are any cycles in the DAG.

    Runs an iterative depth-first search over ``dag.task_dict``, following each
    task's ``get_direct_relative_ids()``. Every task moves through the states
    ``CYCLE_NEW`` -> ``CYCLE_IN_PROGRESS`` -> ``CYCLE_DONE``; reaching a task
    that is still ``CYCLE_IN_PROGRESS`` (i.e. on the current search path) means
    the graph contains a cycle.

    :param dag: the DAG to check; only its ``task_dict`` mapping of task id to
        task is used.
    :return: False if no cycle is found.
    :raises AirflowDagCycleException: if a cycle is found. The message names the
        task that has a relative closing the cycle.
    """
    # default of int is 0 which corresponds to CYCLE_NEW
    visited = defaultdict(int)
    # Current search path as (task_id, iterator over that task's relatives)
    # pairs. Keeping the iterator means each task's relatives are fetched and
    # scanned once in total, instead of being rescanned from the start every
    # time the search backtracks to that task.
    path_stack = deque()
    task_dict = dag.task_dict

    def _push(task_id):
        """Mark ``task_id`` as on the current path and make it the top of the stack."""
        visited[task_id] = CYCLE_IN_PROGRESS
        path_stack.append((task_id, iter(task_dict[task_id].get_direct_relative_ids())))

    def _check_adjacent_tasks(task_id, relatives):
        """Return the next untraversed relative from ``relatives``, else None if all are done."""
        for adjacent_task in relatives:
            if visited[adjacent_task] == CYCLE_IN_PROGRESS:
                msg = f"Cycle detected in DAG. Faulty task: {task_id}"
                raise AirflowDagCycleException(msg)
            if visited[adjacent_task] == CYCLE_NEW:
                return adjacent_task
        return None

    for dag_task_id in task_dict:
        if visited[dag_task_id] == CYCLE_DONE:
            continue
        _push(dag_task_id)
        while path_stack:
            current_task_id, relatives = path_stack[-1]
            child_to_check = _check_adjacent_tasks(current_task_id, relatives)
            # Compare against None explicitly: a falsy task id such as "" is
            # still a real child and must be traversed.
            if child_to_check is None:
                visited[current_task_id] = CYCLE_DONE
                path_stack.pop()
            else:
                _push(child_to_check)
    return False


# The ``test_`` prefix makes pytest try to collect this function as a test (and
# fail on the missing ``dag`` fixture) in any test module that imports it.
test_cycle.__test__ = False