blob: 470b4cdaab49e1fc5c2533300a4959a43a0f8831 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from airflow.models import Operator
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.dag import DAG
def dag_edges(dag: DAG):
"""
Create the list of edges needed to construct the Graph view.
A special case is made if a TaskGroup is immediately upstream/downstream of another
TaskGroup or task. Two proxy nodes named upstream_join_id and downstream_join_id are
created for the TaskGroup. Instead of drawing an edge onto every task in the TaskGroup,
all edges are directed onto the proxy nodes. This is to cut down the number of edges on
the graph.
For example: A DAG with TaskGroups group1 and group2:
group1: task1, task2, task3
group2: task4, task5, task6
group2 is downstream of group1:
group1 >> group2
Edges to add (This avoids having to create edges between every task in group1 and group2):
task1 >> downstream_join_id
task2 >> downstream_join_id
task3 >> downstream_join_id
downstream_join_id >> upstream_join_id
upstream_join_id >> task4
upstream_join_id >> task5
upstream_join_id >> task6
"""
# Edges to add between TaskGroup
edges_to_add = set()
# Edges to remove between individual tasks that are replaced by edges_to_add.
edges_to_skip = set()
task_group_map = dag.task_group.get_task_group_dict()
def collect_edges(task_group):
"""Update edges_to_add and edges_to_skip according to TaskGroups."""
if isinstance(task_group, AbstractOperator):
return
for target_id in task_group.downstream_group_ids:
# For every TaskGroup immediately downstream, add edges between downstream_join_id
# and upstream_join_id. Skip edges between individual tasks of the TaskGroups.
target_group = task_group_map[target_id]
edges_to_add.add((task_group.downstream_join_id, target_group.upstream_join_id))
for child in task_group.get_leaves():
edges_to_add.add((child.task_id, task_group.downstream_join_id))
for target in target_group.get_roots():
edges_to_skip.add((child.task_id, target.task_id))
edges_to_skip.add((child.task_id, target_group.upstream_join_id))
for child in target_group.get_roots():
edges_to_add.add((target_group.upstream_join_id, child.task_id))
edges_to_skip.add((task_group.downstream_join_id, child.task_id))
# For every individual task immediately downstream, add edges between downstream_join_id and
# the downstream task. Skip edges between individual tasks of the TaskGroup and the
# downstream task.
for target_id in task_group.downstream_task_ids:
edges_to_add.add((task_group.downstream_join_id, target_id))
for child in task_group.get_leaves():
edges_to_add.add((child.task_id, task_group.downstream_join_id))
edges_to_skip.add((child.task_id, target_id))
# For every individual task immediately upstream, add edges between the upstream task
# and upstream_join_id. Skip edges between the upstream task and individual tasks
# of the TaskGroup.
for source_id in task_group.upstream_task_ids:
edges_to_add.add((source_id, task_group.upstream_join_id))
for child in task_group.get_roots():
edges_to_add.add((task_group.upstream_join_id, child.task_id))
edges_to_skip.add((source_id, child.task_id))
for child in task_group.children.values():
collect_edges(child)
collect_edges(dag.task_group)
# Collect all the edges between individual tasks
edges = set()
tasks_to_trace: list[Operator] = dag.roots
while tasks_to_trace:
tasks_to_trace_next: list[Operator] = []
for task in tasks_to_trace:
for child in task.downstream_list:
edge = (task.task_id, child.task_id)
if edge in edges:
continue
edges.add(edge)
tasks_to_trace_next.append(child)
tasks_to_trace = tasks_to_trace_next
result = []
# Build result dicts with the two ends of the edge, plus any extra metadata
# if we have it.
for source_id, target_id in sorted(edges.union(edges_to_add) - edges_to_skip):
record = {"source_id": source_id, "target_id": target_id}
label = dag.get_edge_info(source_id, target_id).get("label")
if label:
record["label"] = label
result.append(record)
return result