| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| from airflow.models import Operator |
| from airflow.models.abstractoperator import AbstractOperator |
| from airflow.models.dag import DAG |
| |
| |
def dag_edges(dag: DAG) -> list[dict]:
    """
    Create the list of edges needed to construct the Graph view.

    A special case is made if a TaskGroup is immediately upstream/downstream of another
    TaskGroup or task. Two proxy nodes named upstream_join_id and downstream_join_id are
    created for the TaskGroup. Instead of drawing an edge onto every task in the TaskGroup,
    all edges are directed onto the proxy nodes. This is to cut down the number of edges on
    the graph.

    For example: A DAG with TaskGroups group1 and group2:
        group1: task1, task2, task3
        group2: task4, task5, task6

    group2 is downstream of group1:
        group1 >> group2

    Edges to add (This avoids having to create edges between every task in group1 and group2):
        task1 >> downstream_join_id
        task2 >> downstream_join_id
        task3 >> downstream_join_id
        downstream_join_id >> upstream_join_id
        upstream_join_id >> task4
        upstream_join_id >> task5
        upstream_join_id >> task6

    :param dag: DAG whose task and TaskGroup dependencies are turned into edges.
    :return: Edge records sorted by ``(source_id, target_id)``. Each record is a dict with
        ``source_id`` and ``target_id`` keys, plus a ``label`` key when
        ``dag.get_edge_info`` reports a truthy label for that pair.
    """
    # Edges to add between TaskGroup
    edges_to_add = set()
    # Edges to remove between individual tasks that are replaced by edges_to_add.
    edges_to_skip = set()

    task_group_map = dag.task_group.get_task_group_dict()

    def collect_edges(task_group):
        """Update edges_to_add and edges_to_skip according to TaskGroups."""
        # Plain tasks have no join nodes and no children to recurse into.
        if isinstance(task_group, AbstractOperator):
            return

        if task_group.downstream_group_ids or task_group.downstream_task_ids:
            downstream_join_id = task_group.downstream_join_id
            # The leaves do not depend on the downstream target, so evaluate get_leaves()
            # once and reuse the ids for every downstream group and task below.
            leaf_ids = [leaf.task_id for leaf in task_group.get_leaves()]
            edges_to_add.update((leaf_id, downstream_join_id) for leaf_id in leaf_ids)

            # For every TaskGroup immediately downstream, add edges between downstream_join_id
            # and upstream_join_id. Skip edges between individual tasks of the TaskGroups.
            for target_id in task_group.downstream_group_ids:
                target_group = task_group_map[target_id]
                target_join_id = target_group.upstream_join_id
                # Evaluate the target's roots once rather than once per leaf.
                target_root_ids = [root.task_id for root in target_group.get_roots()]

                edges_to_add.add((downstream_join_id, target_join_id))

                for leaf_id in leaf_ids:
                    edges_to_skip.update((leaf_id, root_id) for root_id in target_root_ids)
                    edges_to_skip.add((leaf_id, target_join_id))

                for root_id in target_root_ids:
                    edges_to_add.add((target_join_id, root_id))
                    edges_to_skip.add((downstream_join_id, root_id))

            # For every individual task immediately downstream, add edges between
            # downstream_join_id and the downstream task. Skip edges between individual
            # tasks of the TaskGroup and the downstream task.
            for target_id in task_group.downstream_task_ids:
                edges_to_add.add((downstream_join_id, target_id))
                edges_to_skip.update((leaf_id, target_id) for leaf_id in leaf_ids)

        # For every individual task immediately upstream, add edges between the upstream task
        # and upstream_join_id. Skip edges between the upstream task and individual tasks
        # of the TaskGroup.
        if task_group.upstream_task_ids:
            upstream_join_id = task_group.upstream_join_id
            # The roots do not depend on the upstream task, so evaluate get_roots() once.
            root_ids = [root.task_id for root in task_group.get_roots()]
            edges_to_add.update((upstream_join_id, root_id) for root_id in root_ids)

            for source_id in task_group.upstream_task_ids:
                edges_to_add.add((source_id, upstream_join_id))
                edges_to_skip.update((source_id, root_id) for root_id in root_ids)

        for child in task_group.children.values():
            collect_edges(child)

    collect_edges(dag.task_group)

    # Collect all the edges between individual tasks
    edges = set()

    # Level-by-level walk starting from the DAG roots. A child is only queued when the
    # edge leading to it is new, which is what terminates the walk.
    tasks_to_trace: list[Operator] = dag.roots
    while tasks_to_trace:
        tasks_to_trace_next: list[Operator] = []
        for task in tasks_to_trace:
            for child in task.downstream_list:
                edge = (task.task_id, child.task_id)
                if edge in edges:
                    continue
                edges.add(edge)
                tasks_to_trace_next.append(child)
        tasks_to_trace = tasks_to_trace_next

    result = []
    # Build result dicts with the two ends of the edge, plus any extra metadata
    # if we have it.
    for source_id, target_id in sorted(edges.union(edges_to_add) - edges_to_skip):
        record = {"source_id": source_id, "target_id": target_id}
        label = dag.get_edge_info(source_id, target_id).get("label")
        if label:
            record["label"] = label
        result.append(record)
    return result