# blob: 85ec1eb0d3b09e2561eadc27f43f148288fd4e86 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped
together when the DAG is displayed graphically.
"""
from __future__ import annotations
import copy
import functools
import operator
import re
import weakref
from typing import TYPE_CHECKING, Any, Generator, Iterator, Sequence
from airflow.compat.functools import cache
from airflow.exceptions import (
AirflowDagCycleException,
AirflowException,
DuplicateTaskIdFound,
TaskAlreadyInTaskGroup,
)
from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.serialization.enums import DagAttributeTypes
from airflow.utils.helpers import validate_group_key
if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.expandinput import ExpandInput
from airflow.models.operator import Operator
from airflow.utils.edgemodifier import EdgeModifier
class TaskGroup(DAGNode):
    """
    A collection of tasks. When set_downstream() or set_upstream() are called on the
    TaskGroup, it is applied across all tasks within the group if necessary.

    :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
        with group_id of TaskGroup or task_id of tasks in the DAG. Root TaskGroup has group_id
        set to None.
    :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
        this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
        Default is True.
    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
        for the root TaskGroup.
    :param dag: The DAG that this TaskGroup belongs to.
    :param default_args: A dictionary of default parameters to be used
        as constructor keyword parameters when initialising operators,
        will override default_args defined in the DAG level.
        Note that operators have the same hook, and precede those defined
        here, meaning that if your dict contains `'depends_on_past': True`
        here and `'depends_on_past': False` in the operator's call
        `default_args`, the actual value will be `False`.
    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
    :param ui_color: The fill color of the TaskGroup node when displayed in the UI
    :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
    :param add_suffix_on_collision: If this task group name already exists,
        automatically add `__1` etc suffixes
    """

    # Shared (same set object) between a TaskGroup and its parent; see __init__.
    used_group_ids: set[str | None]

    def __init__(
        self,
        group_id: str | None,
        prefix_group_id: bool = True,
        parent_group: TaskGroup | None = None,
        dag: DAG | None = None,
        default_args: dict[str, Any] | None = None,
        tooltip: str = "",
        ui_color: str = "CornflowerBlue",
        ui_fgcolor: str = "#000",
        add_suffix_on_collision: bool = False,
    ):
        from airflow.models.dag import DagContext

        self.prefix_group_id = prefix_group_id
        # Deep-copied so later mutation of the caller's dict does not leak into this group.
        self.default_args = copy.deepcopy(default_args or {})

        # Fall back to the DAG currently active in DagContext when none is given.
        dag = dag or DagContext.get_current_dag()

        if group_id is None:
            # This creates a root TaskGroup.
            if parent_group:
                raise AirflowException("Root TaskGroup cannot have parent_group")
            # used_group_ids is shared across all TaskGroups in the same DAG to keep track
            # of used group_id to avoid duplication.
            self.used_group_ids = set()
            self.dag = dag
        else:
            if prefix_group_id:
                # If group id is used as prefix, it should not contain spaces nor dots
                # because it is used as prefix in the task_id
                validate_group_key(group_id)
            else:
                if not isinstance(group_id, str):
                    raise ValueError("group_id must be str")
                if not group_id:
                    raise ValueError("group_id must not be empty")

            if not parent_group and not dag:
                raise AirflowException("TaskGroup can only be used inside a dag")

            parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
            if not parent_group:
                raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
            if dag is not parent_group.dag:
                # Interpolate the message here: exception constructors, unlike logging
                # calls, do not apply %-style arguments.
                raise RuntimeError(
                    f"Cannot mix TaskGroups from different DAGs: {dag} and {parent_group.dag}"
                )

            self.used_group_ids = parent_group.used_group_ids

        self._group_id = group_id
        # May rewrite self._group_id with a ``__N`` suffix, or raise on a duplicate.
        self._check_for_group_id_collisions(add_suffix_on_collision)

        self.children: dict[str, DAGNode] = {}
        if parent_group:
            # Registering with the parent also sets self.task_group, which the
            # group_id property below relies on for prefixing.
            parent_group.add(self)

        self.used_group_ids.add(self.group_id)
        if self.group_id:
            # Reserve the ids of the proxy join nodes so nothing else can claim them.
            self.used_group_ids.add(self.downstream_join_id)
            self.used_group_ids.add(self.upstream_join_id)

        self.tooltip = tooltip
        self.ui_color = ui_color
        self.ui_fgcolor = ui_fgcolor

        # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
        # so that we can optimize the number of edges when entire TaskGroups depend on each other.
        self.upstream_group_ids: set[str | None] = set()
        self.downstream_group_ids: set[str | None] = set()
        self.upstream_task_ids = set()
        self.downstream_task_ids = set()

    def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
        """
        Resolve a clash between ``self._group_id`` and ``self.used_group_ids``.

        If the id is already used and *add_suffix_on_collision* is true, the id is replaced
        by ``<base>__<N>``, where N is one more than the largest numeric suffix already used
        for the same base. Example: task_group ==> task_group__1 -> task_group__2.

        :param add_suffix_on_collision: rename instead of raising when the id is taken.
        :raises DuplicateTaskIdFound: if the id is taken and renaming is not allowed.
        """
        if self._group_id is None:
            return
        if self._group_id not in self.used_group_ids:
            return
        if not add_suffix_on_collision:
            raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG")

        # Strip an existing ``__N`` suffix so that "tg__2" continues the "tg" sequence.
        base = re.split(r"__\d+$", self._group_id)[0]
        # The base is escaped because ids are not restricted to regex-safe characters
        # when prefix_group_id is False.
        suffix_pattern = re.compile(rf"^{re.escape(base)}__(\d+)$")
        suffixes = []
        for used_group_id in self.used_group_ids:
            if used_group_id is None:
                continue
            found = suffix_pattern.match(used_group_id)
            if found:
                suffixes.append(int(found.group(1)))

        if suffixes:
            self._group_id = f"{base}__{max(suffixes) + 1}"
        else:
            self._group_id += "__1"

    @classmethod
    def create_root(cls, dag: DAG) -> TaskGroup:
        """Create a root TaskGroup with no group_id or parent."""
        return cls(group_id=None, dag=dag)

    @property
    def node_id(self):
        """Identifier of this node; for a TaskGroup this is its group_id."""
        return self.group_id

    @property
    def is_root(self) -> bool:
        """Returns True if this TaskGroup is the root TaskGroup. Otherwise False."""
        return not self.group_id

    @property
    def parent_group(self) -> TaskGroup | None:
        """The TaskGroup this group was added to (alias of ``task_group``)."""
        return self.task_group

    def __iter__(self):
        """Yield every non-TaskGroup child, descending recursively into nested TaskGroups."""
        for child in self.children.values():
            if isinstance(child, TaskGroup):
                yield from child
            else:
                yield child

    def add(self, task: DAGNode) -> None:
        """Add a task to this TaskGroup.

        :param task: the operator or (empty) TaskGroup to register as a child.
        :raises TaskAlreadyInTaskGroup: if an operator already belongs to another group.
        :raises DuplicateTaskIdFound: if a child with the same node id already exists.
        :raises RuntimeError: if a child TaskGroup belongs to a different DAG.
        :raises AirflowException: if a child TaskGroup already has children.

        :meta private:
        """
        from airflow.models.abstractoperator import AbstractOperator

        existing_tg = task.task_group
        if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
            raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)

        # Set the TG first, as setting it might change the return value of node_id!
        task.task_group = weakref.proxy(self)
        key = task.node_id

        if key in self.children:
            node_type = "Task" if hasattr(task, "task_id") else "Task Group"
            raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG")

        if isinstance(task, TaskGroup):
            if self.dag:
                if task.dag is not None and self.dag is not task.dag:
                    # Interpolate explicitly; exceptions do not format %-style arguments.
                    raise RuntimeError(
                        f"Cannot mix TaskGroups from different DAGs: {self.dag} and {task.dag}"
                    )
                task.dag = self.dag
            if task.children:
                raise AirflowException("Cannot add a non-empty TaskGroup")

        self.children[key] = task

    def _remove(self, task: DAGNode) -> None:
        """
        Remove a direct child and release its id from ``used_group_ids``.

        :raises KeyError: if *task* is not a direct child of this group.
        """
        key = task.node_id

        if key not in self.children:
            raise KeyError(f"Node id {key!r} not part of this task group")

        self.used_group_ids.remove(key)
        del self.children[key]

    @property
    def group_id(self) -> str | None:
        """group_id of this TaskGroup."""
        if self.task_group and self.task_group.prefix_group_id and self.task_group._group_id:
            # defer to parent whether it adds a prefix
            return self.task_group.child_id(self._group_id)

        return self._group_id

    @property
    def label(self) -> str | None:
        """group_id excluding parent's group_id used as the node label in UI."""
        return self._group_id

    def update_relative(
        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
    ) -> None:
        """
        Overrides TaskMixin.update_relative.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
        accordingly so that we can reduce the number of edges when displaying Graph view.

        :param other: the TaskGroup or task-like object on the other end of the relationship.
        :param upstream: True if *other* is upstream of this group, False if downstream.
        :param edge_modifier: if given, edge info is recorded between the relevant join nodes.
        :raises AirflowException: if a root of *other* is not a DAGNode.
        """
        if isinstance(other, TaskGroup):
            # Handles setting relationship between a TaskGroup and another TaskGroup
            if upstream:
                parent, child = (self, other)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
            else:
                parent, child = (other, self)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)

            parent.upstream_group_ids.add(child.group_id)
            child.downstream_group_ids.add(parent.group_id)
        else:
            # Handles setting relationship between a TaskGroup and a task
            for task in other.roots:
                if not isinstance(task, DAGNode):
                    raise AirflowException(
                        "Relationships can only be set between TaskGroup "
                        f"or operators; received {task.__class__.__name__}"
                    )

                # Do not set a relationship between a TaskGroup and a Label's roots
                if self == task:
                    continue

                if upstream:
                    self.upstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
                else:
                    self.downstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
        """
        # Normalise a single item to a one-element list.
        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        for task_like in task_or_task_list:
            self.update_relative(task_like, upstream, edge_modifier=edge_modifier)

        if upstream:
            for task in self.get_roots():
                task.set_upstream(task_or_task_list)
        else:
            for task in self.get_leaves():
                task.set_downstream(task_or_task_list)

    def __enter__(self) -> TaskGroup:
        """Make this TaskGroup the current one in TaskGroupContext."""
        TaskGroupContext.push_context_managed_task_group(self)
        return self

    def __exit__(self, _type, _value, _tb):
        """Restore the previously current TaskGroup in TaskGroupContext."""
        TaskGroupContext.pop_context_managed_task_group()

    def has_task(self, task: BaseOperator) -> bool:
        """Returns True if this TaskGroup or its children TaskGroups contains the given task."""
        if task.task_id in self.children:
            return True

        return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))

    @property
    def roots(self) -> list[BaseOperator]:
        """Required by TaskMixin."""
        return list(self.get_roots())

    @property
    def leaves(self) -> list[BaseOperator]:
        """Required by TaskMixin."""
        return list(self.get_leaves())

    def get_roots(self) -> Generator[BaseOperator, None, None]:
        """
        Returns a generator of tasks that are root tasks, i.e. those with no upstream
        dependencies within the TaskGroup.
        """
        for task in self:
            if not any(self.has_task(parent) for parent in task.get_direct_relatives(upstream=True)):
                yield task

    def get_leaves(self) -> Generator[BaseOperator, None, None]:
        """
        Returns a generator of tasks that are leaf tasks, i.e. those with no downstream
        dependencies within the TaskGroup.
        """
        for task in self:
            if not any(self.has_task(child) for child in task.get_direct_relatives(upstream=False)):
                yield task

    def child_id(self, label):
        """
        Prefix label with group_id if prefix_group_id is True. Otherwise return the label
        as-is.
        """
        if self.prefix_group_id:
            group_id = self.group_id
            if group_id:
                return f"{group_id}.{label}"

        return label

    @property
    def upstream_join_id(self) -> str:
        """
        If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called
        upstream_join_id will be created in Graph view to join the incoming edges to this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.upstream_join_id"

    @property
    def downstream_join_id(self) -> str:
        """
        If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called
        downstream_join_id will be created in Graph view to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.downstream_join_id"

    def get_task_group_dict(self) -> dict[str, TaskGroup]:
        """Returns a flat dictionary of group_id: TaskGroup."""
        task_group_map = {}

        def build_map(task_group):
            # Non-TaskGroup children (tasks) are skipped; only groups are recorded.
            if not isinstance(task_group, TaskGroup):
                return

            task_group_map[task_group.group_id] = task_group

            for child in task_group.children.values():
                build_map(child)

        build_map(self)
        return task_group_map

    def get_child_by_label(self, label: str) -> DAGNode:
        """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
        return self.children[self.child_id(label)]

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Required by DAGNode."""
        from airflow.serialization.serialized_objects import TaskGroupSerialization

        return DagAttributeTypes.TASK_GROUP, TaskGroupSerialization.serialize_task_group(self)

    def topological_sort(self, _include_subdag_tasks: bool = False):
        """
        Sorts children in topological order, such that a task comes after any of its
        upstream dependencies.

        :param _include_subdag_tasks: if True, the sorted tasks of each SubDagOperator's
            subdag are inserted right after that operator.
        :raises AirflowDagCycleException: if a full pass resolves no node.
        :return: list of tasks in topological order
        """
        # This uses a modified version of Kahn's Topological Sort algorithm to
        # not have to pre-compute the "in-degree" of the nodes.
        from airflow.operators.subdag import SubDagOperator  # Avoid circular import

        graph_unsorted = copy.copy(self.children)

        graph_sorted: list[DAGNode] = []

        # special case
        if len(self.children) == 0:
            return graph_sorted

        # Run until the unsorted graph is empty.
        while graph_unsorted:
            # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
            # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
            # pair from the unsorted graph, and append it to the sorted graph. Note here that by iterating
            # over a list() of the values, a copy of the unsorted graph is used, allowing us to modify
            # the unsorted graph as we move through it.
            #
            # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
            # during each pass through the graph. If not, we need to exit as the graph therefore can't be
            # sorted.
            acyclic = False
            for node in list(graph_unsorted.values()):
                for edge in node.upstream_list:
                    if edge.node_id in graph_unsorted:
                        break
                    # Check for task's group is a child (or grand child) of this TG,
                    tg = edge.task_group
                    while tg:
                        if tg.node_id in graph_unsorted:
                            break
                        tg = tg.task_group

                    if tg:
                        # We are already going to visit that TG
                        break
                else:
                    # No unresolved upstream was found for this node: it can be emitted now.
                    acyclic = True
                    del graph_unsorted[node.node_id]
                    graph_sorted.append(node)
                    if _include_subdag_tasks and isinstance(node, SubDagOperator):
                        graph_sorted.extend(
                            node.subdag.task_group.topological_sort(_include_subdag_tasks=True)
                        )

            if not acyclic:
                raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")

        return graph_sorted

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """Return mapped task groups in the hierarchy.

        Groups are returned from the closest to the outmost. If *self* is a
        mapped task group, it is returned first.

        :meta private:
        """
        group: TaskGroup | None = self
        while group is not None:
            if isinstance(group, MappedTaskGroup):
                yield group
            group = group.task_group

    def iter_tasks(self) -> Iterator[AbstractOperator]:
        """Returns an iterator of the child tasks.

        Groups are visited breadth-first, starting with this one.

        :raises ValueError: if a child is neither an AbstractOperator nor a TaskGroup.
        """
        from airflow.models.abstractoperator import AbstractOperator

        groups_to_visit = [self]

        while groups_to_visit:
            visiting = groups_to_visit.pop(0)

            for child in visiting.children.values():
                if isinstance(child, AbstractOperator):
                    yield child
                elif isinstance(child, TaskGroup):
                    groups_to_visit.append(child)
                else:
                    raise ValueError(
                        f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child)}"
                    )
class MappedTaskGroup(TaskGroup):
    """A mapped task group.

    Behaves like a regular TaskGroup; it only carries the extra expansion
    metadata that is consumed later.

    Don't instantiate this class directly; call *expand* or *expand_kwargs* on
    a ``@task_group`` function instead.
    """

    def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._expand_input = expand_input
        # Every operator referenced by the expand input becomes an upstream of the group.
        for referenced_op, _ in expand_input.iter_references():
            self.set_upstream(referenced_op)

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this mapped task group."""
        from airflow.models.xcom_arg import XComArg

        references = XComArg.iter_xcom_references(self._expand_input)
        yield from (dependency for dependency, _ in references)

    @cache
    def get_parse_time_mapped_ti_count(self) -> int:
        """Number of instances a task in this group should be mapped to, when a DAG run is created.

        Only literal mapped arguments are taken into account.

        If this group is inside mapped task groups, all the nested counts are
        multiplied and accounted.

        :meta private:

        :raise NotFullyPopulated: If any non-literal mapped arguments are encountered.
        :return: The total number of mapped instances each task should have.
        """
        # One factor per mapped group, from this group outwards.
        per_group_counts = (
            mapped_group._expand_input.get_parse_time_mapped_ti_count()
            for mapped_group in self.iter_mapped_task_groups()
        )
        return functools.reduce(operator.mul, per_group_counts)

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """Number of instances a task in this group should be mapped to at run time.

        This considers both literal and non-literal mapped arguments, and the
        result is therefore available when all depended tasks have finished. The
        return value should be identical to ``parse_time_mapped_ti_count`` if
        all mapped arguments are literal.

        If this group is inside mapped task groups, all the nested counts are
        multiplied and accounted.

        :meta private:

        :raise NotFullyPopulated: If upstream tasks are not all complete yet.
        :return: Total number of mapped TIs this task should have.
        """
        # One factor per mapped group, from this group outwards.
        per_group_lengths = (
            mapped_group._expand_input.get_total_map_length(run_id, session=session)
            for mapped_group in self.iter_mapped_task_groups()
        )
        return functools.reduce(operator.mul, per_group_lengths)
class TaskGroupContext:
    """Tracks which TaskGroup is current while TaskGroups are used as context managers."""

    # The TaskGroup whose ``with`` block is innermost right now, if any.
    _context_managed_task_group: TaskGroup | None = None
    # Enclosing TaskGroups, outermost first, waiting to be restored on exit.
    _previous_context_managed_task_groups: list[TaskGroup] = []

    @classmethod
    def push_context_managed_task_group(cls, task_group: TaskGroup):
        """Make *task_group* current, stashing the one that was current before it."""
        outgoing = cls._context_managed_task_group
        if outgoing:
            cls._previous_context_managed_task_groups.append(outgoing)
        cls._context_managed_task_group = task_group

    @classmethod
    def pop_context_managed_task_group(cls) -> TaskGroup | None:
        """Drop the current TaskGroup, restore the most recently stashed one, and return the dropped one."""
        dropped = cls._context_managed_task_group
        stashed = cls._previous_context_managed_task_groups
        cls._context_managed_task_group = stashed.pop() if stashed else None
        return dropped

    @classmethod
    def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None:
        """Return the current TaskGroup, falling back to the root TaskGroup of *dag* (or of the current DAG)."""
        from airflow.models.dag import DagContext

        current = cls._context_managed_task_group
        if current:
            return current

        # No TaskGroup is active: use the given DAG, or else the one active in DagContext.
        dag = dag or DagContext.get_current_dag()
        if dag:
            return dag.task_group
        return current
def task_group_to_dict(task_item_or_group):
    """
    Build the nested dict describing a task or a TaskGroup (with all of its children)
    that is used to construct the Graph.
    """
    from airflow.models.abstractoperator import AbstractOperator

    node = task_item_or_group

    # A single operator is rendered as a leaf node.
    if isinstance(node, AbstractOperator):
        operator_value = {
            "label": node.label,
            "labelStyle": f"fill:{node.ui_fgcolor};",
            "style": f"fill:{node.ui_color};",
            "rx": 5,
            "ry": 5,
        }
        return {"id": node.task_id, "value": operator_value}

    def join_node(join_id):
        # Unlabelled circular proxy node, coloured like the group it belongs to.
        return {
            "id": join_id,
            "value": {
                "label": "",
                "labelStyle": f"fill:{node.ui_fgcolor};",
                "style": f"fill:{node.ui_color};",
                "shape": "circle",
            },
        }

    is_mapped = isinstance(node, MappedTaskGroup)

    # Children are rendered recursively, ordered by their label.
    ordered_children = sorted(node.children.values(), key=lambda child: child.label)
    children = [task_group_to_dict(child) for child in ordered_children]

    if node.upstream_group_ids or node.upstream_task_ids:
        children.append(join_node(node.upstream_join_id))

    if node.downstream_group_ids or node.downstream_task_ids:
        # This is the join node used to reduce the number of edges between two TaskGroup.
        children.append(join_node(node.downstream_join_id))

    group_value = {
        "label": node.label,
        "labelStyle": f"fill:{node.ui_fgcolor};",
        "style": f"fill:{node.ui_color}",
        "rx": 5,
        "ry": 5,
        "clusterLabelPos": "top",
        "tooltip": node.tooltip,
        "isMapped": is_mapped,
    }
    return {"id": node.group_id, "value": group_value, "children": children}