blob: 85ec1eb0d3b09e2561eadc27f43f148288fd4e86 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped
together when the DAG is displayed graphically.
"""
from __future__ import annotations
import copy
import functools
import operator
import re
import weakref
from typing import TYPE_CHECKING, Any, Generator, Iterator, Sequence
from airflow.compat.functools import cache
from airflow.exceptions import (
from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.serialization.enums import DagAttributeTypes
from airflow.utils.helpers import validate_group_key
from sqlalchemy.orm import Session
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.expandinput import ExpandInput
from airflow.models.operator import Operator
from airflow.utils.edgemodifier import EdgeModifier
# TaskGroup: a DAGNode that owns child tasks/TaskGroups in ``self.children`` (keyed by
# node_id) and fans set_upstream()/set_downstream() out to its root/leaf tasks.
# NOTE(review): this copy of the class appears to have lost lines in extraction —
# docstring quote delimiters, the ``self`` parameters of __init__/_set_relatives,
# ``else:`` branches, decorators, loop bodies and closing brackets are not visible.
# The code lines below are kept verbatim; NOTE(review) comments mark where a line is
# presumed missing. Confirm against the upstream source before relying on this copy.
class TaskGroup(DAGNode):
A collection of tasks. When set_downstream() or set_upstream() are called on the
TaskGroup, it is applied across all tasks within the group if necessary.
:param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
with group_id of TaskGroup or task_id of tasks in the DAG. Root TaskGroup has group_id
set to None.
:param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
Default is True.
:param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
for the root TaskGroup.
:param dag: The DAG that this TaskGroup belongs to.
:param default_args: A dictionary of default parameters to be used
as constructor keyword parameters when initialising operators,
will override default_args defined in the DAG level.
Note that operators have the same hook, and precede those defined
here, meaning that if your dict contains `'depends_on_past': True`
here and `'depends_on_past': False` in the operator's call
`default_args`, the actual value will be `False`.
:param tooltip: The tooltip of the TaskGroup node when displayed in the UI
:param ui_color: The fill color of the TaskGroup node when displayed in the UI
:param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
:param add_suffix_on_collision: If this task group name already exists,
automatically add `__1` etc suffixes
used_group_ids: set[str | None]
# ^ Ids already taken in this DAG. The root group creates the set and every descendant
# shares that same set object (see __init__), so collisions are detected DAG-wide.
# Create a TaskGroup. group_id=None builds the DAG's root group. Any other group_id
# builds a child group attached to parent_group, or, when that is not given, to the
# group returned by TaskGroupContext.get_current_task_group(dag).
# Raises AirflowException for a root group given a parent, for a non-root group with
# neither a DAG nor a parent, or when no parent group can be found. Raises ValueError
# for a non-str or empty group_id (where those checks apply), and RuntimeError when dag
# is not the parent group's DAG.
def __init__(
# NOTE(review): the leading ``self,`` parameter line appears to be missing here.
group_id: str | None,
prefix_group_id: bool = True,
parent_group: TaskGroup | None = None,
dag: DAG | None = None,
default_args: dict[str, Any] | None = None,
tooltip: str = "",
ui_color: str = "CornflowerBlue",
ui_fgcolor: str = "#000",
add_suffix_on_collision: bool = False,
# NOTE(review): the closing ``):`` of the signature appears to be missing here.
# Local import, presumably to avoid a circular import with airflow.models.dag.
from airflow.models.dag import DagContext
self.prefix_group_id = prefix_group_id
# Deep copy so later mutation of the caller's dict cannot leak into this group.
self.default_args = copy.deepcopy(default_args or {})
# Fall back to the DAG from DagContext (presumably the enclosing ``with DAG(...)``).
dag = dag or DagContext.get_current_dag()
if group_id is None:
# This creates a root TaskGroup.
if parent_group:
raise AirflowException("Root TaskGroup cannot have parent_group")
# used_group_ids is shared across all TaskGroups in the same DAG to keep track
# of used group_id to avoid duplication.
self.used_group_ids = set()
self.dag = dag
# NOTE(review): an ``else:`` presumably starts here. Everything down to the
# ``self.used_group_ids = parent_group.used_group_ids`` line handles non-root groups.
if prefix_group_id:
# If group id is used as prefix, it should not contain spaces nor dots
# because it is used as prefix in the task_id
# NOTE(review): the validation call this comment describes is not visible. It is
# presumably validate_group_key(group_id), which is imported at module level but
# otherwise unused in the visible file. An ``else:`` presumably guards the two
# checks below.
if not isinstance(group_id, str):
raise ValueError("group_id must be str")
if not group_id:
raise ValueError("group_id must not be empty")
if not parent_group and not dag:
raise AirflowException("TaskGroup can only be used inside a dag")
parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
if not parent_group:
raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
if dag is not parent_group.dag:
# NOTE(review): RuntimeError does not %-format its extra arguments, so the message
# is shown as a raw tuple rather than with the DAGs interpolated. The closing ``)``
# of this call also appears to be missing.
raise RuntimeError(
"Cannot mix TaskGroups from different DAGs: %s and %s", dag, parent_group.dag
# Child groups share the root's set (see the comment on the root branch above).
self.used_group_ids = parent_group.used_group_ids
# if given group_id already used assign suffix by incrementing largest used suffix integer
# Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
self._group_id = group_id
# NOTE(review): a call to self._check_for_group_id_collisions(add_suffix_on_collision)
# presumably belongs here. The comment above describes it, and add_suffix_on_collision
# is otherwise unused in this method.
self.children: dict[str, DAGNode] = {}
if parent_group:
# NOTE(review): the body of this ``if`` and of the next one appear to be missing.
# No visible line adds this group to parent_group.children or records its id in
# used_group_ids. Presumably parent_group.add(self) and used_group_ids.add(...)
# calls belong here — confirm upstream.
if self.group_id:
self.tooltip = tooltip
self.ui_color = ui_color
self.ui_fgcolor = ui_fgcolor
# Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
# so that we can optimize the number of edges when entire TaskGroups depend on each other.
self.upstream_group_ids: set[str | None] = set()
self.downstream_group_ids: set[str | None] = set()
self.upstream_task_ids = set()
self.downstream_task_ids = set()
# Detect a clash between self._group_id and the ids in used_group_ids. On a clash, raise
# DuplicateTaskIdFound unless add_suffix_on_collision is true. In that case, rename the
# group to ``<base>__<n>``, where n is one above the largest numeric suffix in use.
def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
if self._group_id is None:
# NOTE(review): the body of this guard appears to be missing. It is presumably an
# early ``return``, because the root group has no id to check.
# if given group_id already used assign suffix by incrementing largest used suffix integer
# Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
if self._group_id in self.used_group_ids:
if not add_suffix_on_collision:
raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG")
# Strip an existing ``__<n>`` suffix so "tg__2" continues the "tg__<n>" series.
base = re.split(r"__\d+$", self._group_id)[0]
# Numeric suffixes of every used id of the form ``<base>__<n>``, in ascending order.
suffixes = sorted(
int(re.split(r"^.+__", used_group_id)[1])
for used_group_id in self.used_group_ids
# NOTE(review): ``base`` is interpolated into the pattern without re.escape().
# A group id containing regex metacharacters can match unrelated ids or raise
# re.error. This is possible with prefix_group_id=False, where only the str and
# non-empty checks are visible.
if used_group_id is not None and re.match(rf"^{base}__\d+$", used_group_id)
if not suffixes:
self._group_id += "__1"
# NOTE(review): an ``else:`` appears to be missing before the next line. As shown,
# it would also run when ``suffixes`` is empty and raise IndexError.
self._group_id = f"{base}__{suffixes[-1] + 1}"
# Alternate constructor for a DAG's root group. NOTE(review): it takes ``cls``, so it
# is presumably decorated with @classmethod; the decorator line is not visible.
def create_root(cls, dag: DAG) -> TaskGroup:
"""Create a root TaskGroup with no group_id or parent."""
return cls(group_id=None, dag=dag)
# DAGNode id of this group, equal to group_id. It is read as an attribute elsewhere
# (``task.node_id`` in add()), so it is presumably a @property whose decorator is missing.
def node_id(self):
return self.group_id
def is_root(self) -> bool:
"""Returns True if this TaskGroup is the root TaskGroup. Otherwise False."""
# The root group is the one built with group_id=None (see __init__).
return not self.group_id
# The enclosing TaskGroup. It is read from ``task_group``, which the parent's add() sets
# to a weakref.proxy of itself. Presumably a @property (decorator not visible).
def parent_group(self) -> TaskGroup | None:
return self.task_group
# Depth-first iteration over the nodes under this group, recursing into nested groups.
def __iter__(self):
for child in self.children.values():
if isinstance(child, TaskGroup):
yield from child
# NOTE(review): an ``else:`` appears to be missing before the next line. Without it,
# nested TaskGroups would be yielded alongside their tasks.
yield child
# Attach a task, or an empty TaskGroup, as a direct child of this group.
# Raises:
#   - TaskAlreadyInTaskGroup for an operator already owned by another group;
#   - DuplicateTaskIdFound when the child's node_id is already a key in self.children;
#   - RuntimeError for a TaskGroup from a different DAG;
#   - AirflowException for a TaskGroup that already has children.
def add(self, task: DAGNode) -> None:
"""Add a task to this TaskGroup.
:meta private:
from airflow.models.abstractoperator import AbstractOperator
existing_tg = task.task_group
if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)
# Set the TG first, as setting it might change the return value of node_id!
# A weak proxy is used, presumably so the child does not keep its parent alive
# through a reference cycle (self.children already holds the child strongly).
task.task_group = weakref.proxy(self)
key = task.node_id
if key in self.children:
# hasattr(task, "task_id") only selects the wording of the error message.
node_type = "Task" if hasattr(task, "task_id") else "Task Group"
raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG")
if isinstance(task, TaskGroup):
if self.dag:
if task.dag is not None and self.dag is not task.dag:
# NOTE(review): RuntimeError does not %-format these arguments (see __init__).
# The closing ``)`` also appears to be missing.
raise RuntimeError(
"Cannot mix TaskGroups from different DAGs: %s and %s", self.dag, task.dag
# Adopt this group's DAG.
task.dag = self.dag
if task.children:
raise AirflowException("Cannot add a non-empty TaskGroup")
self.children[key] = task
# Detach the child registered under task's node_id. Raises KeyError if task is not a
# direct child.
def _remove(self, task: DAGNode) -> None:
key = task.node_id
if key not in self.children:
raise KeyError(f"Node id {key!r} not part of this task group")
del self.children[key]
# Full id of this group. Presumably a @property, since it is read as ``self.group_id``
# elsewhere in this class.
def group_id(self) -> str | None:
"""group_id of this TaskGroup."""
if self.task_group and self.task_group.prefix_group_id and self.task_group._group_id:
# defer to parent whether it adds a prefix
return self.task_group.child_id(self._group_id)
return self._group_id
# Unprefixed id shown as the UI label. It is read as ``.label`` in task_group_to_dict,
# so it is presumably a @property.
def label(self) -> str | None:
"""group_id excluding parent's group_id used as the node label in UI."""
return self._group_id
def update_relative(
self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
) -> None:
Overrides TaskMixin.update_relative.
Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
accordingly so that we can reduce the number of edges when displaying Graph view.
if isinstance(other, TaskGroup):
# upstream=True means ``other`` is upstream of this group: the edge info below runs
# from other's downstream join node into this group's upstream join node.
# Handles setting relationship between a TaskGroup and another TaskGroup
if upstream:
parent, child = (self, other)
if edge_modifier:
edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
# NOTE(review): an ``else:`` appears to be missing before the next line.
parent, child = (other, self)
if edge_modifier:
edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)
# NOTE(review): ``parent``/``child`` are never used in the visible lines, and no
# visible line updates upstream_group_ids/downstream_group_ids as the docstring
# describes. Those statements, and an ``else:`` for the task branch below, appear
# to be missing.
# Handles setting relationship between a TaskGroup and a task
for task in other.roots:
if not isinstance(task, DAGNode):
raise AirflowException(
"Relationships can only be set between TaskGroup "
f"or operators; received {task.__class__.__name__}"
# Do not set a relationship between a TaskGroup and a Label's roots
if self == task:
# NOTE(review): the body of this ``if`` (presumably ``continue``) is missing. So are
# the upstream_task_ids/downstream_task_ids updates and the ``else:`` between the
# two edge_modifier branches below.
if upstream:
if edge_modifier:
edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
if edge_modifier:
edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)
# Wire this group to one task/group or to a sequence of them. update_relative() records
# the group-level bookkeeping; then the group's roots (upstream=True) or leaves
# (otherwise) are linked one by one.
def _set_relatives(
# NOTE(review): the leading ``self,`` parameter line appears to be missing here.
task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
upstream: bool = False,
edge_modifier: EdgeModifier | None = None,
) -> None:
Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.
Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
if not isinstance(task_or_task_list, Sequence):
task_or_task_list = [task_or_task_list]
for task_like in task_or_task_list:
self.update_relative(task_like, upstream, edge_modifier=edge_modifier)
if upstream:
for task in self.get_roots():
# NOTE(review): both loop bodies appear to be missing (presumably
# task.set_upstream(...)/task.set_downstream(...) calls forwarding edge_modifier),
# as does the ``else:`` between the two loops.
for task in self.get_leaves():
# Context-manager entry. NOTE(review): only ``return self`` is visible. A
# TaskGroupContext.push_context_managed_task_group(self) call presumably precedes it.
def __enter__(self) -> TaskGroup:
return self
# Context-manager exit. NOTE(review): the body is missing entirely. It presumably calls
# TaskGroupContext.pop_context_managed_task_group().
def __exit__(self, _type, _value, _tb):
# Direct children are matched by task_id; nested TaskGroups are searched recursively.
def has_task(self, task: BaseOperator) -> bool:
"""Returns True if this TaskGroup or its children TaskGroups contains the given task."""
if task.task_id in self.children:
return True
return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))
# Materialised get_roots(). It is read as ``other.roots`` in update_relative, so it is
# presumably a @property (decorator not visible). The same applies to ``leaves``.
def roots(self) -> list[BaseOperator]:
"""Required by TaskMixin."""
return list(self.get_roots())
def leaves(self) -> list[BaseOperator]:
"""Required by TaskMixin."""
return list(self.get_leaves())
# A node is a root when none of its direct upstream relatives is anywhere inside this
# group (has_task() also searches nested groups).
def get_roots(self) -> Generator[BaseOperator, None, None]:
Returns a generator of tasks that are root tasks, i.e. those with no upstream
dependencies within the TaskGroup.
for task in self:
if not any(self.has_task(parent) for parent in task.get_direct_relatives(upstream=True)):
yield task
# Mirror of get_roots(), using downstream relatives.
def get_leaves(self) -> Generator[BaseOperator, None, None]:
Returns a generator of tasks that are leaf tasks, i.e. those with no downstream
dependencies within the TaskGroup.
for task in self:
if not any(self.has_task(child) for child in task.get_direct_relatives(upstream=False)):
yield task
# Returns ``"<group_id>.<label>"`` when this group prefixes ids and has a group_id;
# otherwise returns ``label`` unchanged.
def child_id(self, label):
Prefix label with group_id if prefix_group_id is True. Otherwise return the label
if self.prefix_group_id:
group_id = self.group_id
if group_id:
return f"{group_id}.{label}"
return label
# Presumably a @property: it is read as ``other.upstream_join_id`` in update_relative.
def upstream_join_id(self) -> str:
If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called
upstream_join_id will be created in Graph view to join the outgoing edges from this
TaskGroup to reduce the total number of edges needed to be displayed.
return f"{self.group_id}.upstream_join_id"
# Presumably a @property: it is read as ``other.downstream_join_id`` in update_relative.
def downstream_join_id(self) -> str:
If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called
downstream_join_id will be created in Graph view to join the outgoing edges from this
TaskGroup to reduce the total number of edges needed to be displayed.
return f"{self.group_id}.downstream_join_id"
# Flat map of group_id -> TaskGroup for this subtree, filled by the nested build_map().
# NOTE(review): build_map's early ``return`` for non-TaskGroup nodes, its recursive
# ``build_map(child)`` call, and the initial ``build_map(self)`` call appear to be
# missing. As shown, task_group_map would be returned empty.
def get_task_group_dict(self) -> dict[str, TaskGroup]:
"""Returns a flat dictionary of group_id: TaskGroup."""
task_group_map = {}
def build_map(task_group):
if not isinstance(task_group, TaskGroup):
task_group_map[task_group.group_id] = task_group
for child in task_group.children.values():
return task_group_map
def get_child_by_label(self, label: str) -> DAGNode:
"""Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
# Children are keyed by their (possibly prefixed) id, so rebuild it from the label.
return self.children[self.child_id(label)]
def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
"""Required by DAGNode."""
# Local import, presumably to avoid an import cycle with the serialization module.
from airflow.serialization.serialized_objects import TaskGroupSerialization
return DagAttributeTypes.TASK_GROUP, TaskGroupSerialization.serialize_task_group(self)
# Order this group's direct children so each one comes after its upstream dependencies.
# For an unsorted upstream edge, it also walks up the upstream task's enclosing
# TaskGroups to see whether one of them is still unsorted here. Raises
# AirflowDagCycleException when a full pass resolves no node.
def topological_sort(self, _include_subdag_tasks: bool = False):
Sorts children in topographical order, such that a task comes after any of its
upstream dependencies.
:return: list of tasks in topological order
# This uses a modified version of Kahn's Topological Sort algorithm to
# not have to pre-compute the "in-degree" of the nodes.
from airflow.operators.subdag import SubDagOperator # Avoid circular import
graph_unsorted = copy.copy(self.children)
graph_sorted: list[DAGNode] = []
# Special case: a group with no children sorts to an empty list.
if len(self.children) == 0:
return graph_sorted
# Run until the unsorted graph is empty.
while graph_unsorted:
# Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
# any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
# pair from the unsorted graph, and append it to the sorted graph. Note here that by using
# list(graph_unsorted.values()) for iterating, a copy of the unsorted graph is used, allowing us to modify
# the unsorted graph as we move through it.
# We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
# during each pass through the graph. If not, we need to exit as the graph therefore can't be
# sorted.
# NOTE(review): several lines of the loop below appear to be missing: the
# ``break`` statements, the for/else ``else:`` that marks a node as resolved,
# graph_sorted.append(node), and the SubDagOperator expansion body.
acyclic = False
for node in list(graph_unsorted.values()):
for edge in node.upstream_list:
if edge.node_id in graph_unsorted:
# Check whether the upstream task's group is a child (or grandchild) of this TG,
tg = edge.task_group
while tg:
if tg.node_id in graph_unsorted:
tg = tg.task_group
if tg:
# We are already going to visit that TG
acyclic = True
del graph_unsorted[node.node_id]
if _include_subdag_tasks and isinstance(node, SubDagOperator):
if not acyclic:
# NOTE(review): dag_id is not defined in the visible class. It is presumably
# inherited from DAGNode — confirm.
raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")
return graph_sorted
# Walks up the ``task_group`` parent chain, starting at self.
def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
"""Return mapped task groups in the hierarchy.
Groups are returned from the closest to the outermost. If *self* is a
mapped task group, it is returned first.
:meta private:
group: TaskGroup | None = self
while group is not None:
if isinstance(group, MappedTaskGroup):
yield group
group = group.task_group
# Walks nested TaskGroups with a FIFO queue (breadth-first) and yields every operator.
# list.pop(0) is O(n); a collections.deque would make each pop O(1).
def iter_tasks(self) -> Iterator[AbstractOperator]:
"""Returns an iterator of the child tasks."""
from airflow.models.abstractoperator import AbstractOperator
groups_to_visit = [self]
while groups_to_visit:
visiting = groups_to_visit.pop(0)
for child in visiting.children.values():
if isinstance(child, AbstractOperator):
yield child
elif isinstance(child, TaskGroup):
# NOTE(review): the line that queues the nested group (presumably
# groups_to_visit.append(child)) and the ``else:`` before the ValueError appear
# to be missing. As shown, any nested TaskGroup would trigger the ValueError.
raise ValueError(
f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child)}"
# A TaskGroup that additionally stores the ExpandInput its tasks are mapped over.
# NOTE(review): this class appears to have lost lines in extraction (docstring closing
# quotes, decorators, and the lines flagged below). The code lines are kept verbatim.
class MappedTaskGroup(TaskGroup):
"""A mapped task group.
This doesn't really do anything special, just holds some additional metadata
for expansion later.
Don't instantiate this class directly; call *expand* or *expand_kwargs* on
a ``@task_group`` function instead.
def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
# Stores the expansion input; other keyword arguments are presumably for TaskGroup.
# NOTE(review): no super().__init__(**kwargs) call is visible. It is presumably on a
# missing line, since kwargs is otherwise unused.
self._expand_input = expand_input
for op, _ in expand_input.iter_references():
# NOTE(review): the loop body is missing. It presumably registers each referenced
# operator as an upstream dependency of this group (e.g. self.set_upstream(op)).
def iter_mapped_dependencies(self) -> Iterator[Operator]:
"""Upstream dependencies that provide XComs used by this mapped task group."""
# Local import, presumably to avoid an import cycle with airflow.models.xcom_arg.
from airflow.models.xcom_arg import XComArg
for op, _ in XComArg.iter_xcom_references(self._expand_input):
yield op
def get_parse_time_mapped_ti_count(self) -> int:
"""Number of instances a task in this group should be mapped to, when a DAG run is created.
This only considers literal mapped arguments, and raises ``NotFullyPopulated``
when any non-literal values are used for mapping.
If this group is inside mapped task groups, all the nested counts are
multiplied and accounted.
:meta private:
:raise NotFullyPopulated: If any non-literal mapped arguments are encountered.
:return: The total number of mapped instances each task should have.
return functools.reduce(
# NOTE(review): the reducer argument line appears to be missing. It is presumably
# operator.mul: the docstring says nested counts are multiplied, and ``operator`` is
# imported but otherwise unused in the visible file. ``cache`` is likewise imported
# but unused, possibly because a decorator on this method is missing.
(g._expand_input.get_parse_time_mapped_ti_count() for g in self.iter_mapped_task_groups()),
def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
"""Number of instances a task in this group should be mapped to at run time.
This considers both literal and non-literal mapped arguments, and the
result is therefore available when all upstream tasks have finished. The
return value should be identical to ``get_parse_time_mapped_ti_count()`` if
all mapped arguments are literal.
If this group is inside mapped task groups, all the nested counts are
multiplied and accounted.
:meta private:
:raise NotFullyPopulated: If upstream tasks are not all complete yet.
:return: Total number of mapped TIs this task should have.
# Every mapped group from self outwards (see TaskGroup.iter_mapped_task_groups).
groups = self.iter_mapped_task_groups()
return functools.reduce(
# NOTE(review): as in get_parse_time_mapped_ti_count, the reducer argument
# (presumably operator.mul) and the closing ``)`` appear to be missing.
(g._expand_input.get_total_map_length(run_id, session=session) for g in groups),
# Tracks which TaskGroup is "current" while ``with TaskGroup(...)`` blocks are active.
# All state lives in class attributes, so it is shared process-wide (including across
# threads); no thread-local storage is used.
# NOTE(review): the three methods take ``cls`` and are presumably @classmethods; the
# decorator lines are not visible in this copy.
class TaskGroupContext:
"""TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""
# The innermost active group, or None when no TaskGroup context is open.
_context_managed_task_group: TaskGroup | None = None
# Stack of enclosing groups, restored one at a time by pop_context_managed_task_group().
_previous_context_managed_task_groups: list[TaskGroup] = []
def push_context_managed_task_group(cls, task_group: TaskGroup):
"""Push a TaskGroup onto the stack of managed TaskGroups and make it the current one."""
if cls._context_managed_task_group:
# NOTE(review): the body of this ``if`` appears to be missing. It presumably
# appends the current group to _previous_context_managed_task_groups; as shown,
# the outer group would be forgotten.
cls._context_managed_task_group = task_group
def pop_context_managed_task_group(cls) -> TaskGroup | None:
"""Pop the last TaskGroup from the stack of managed TaskGroups and update the current TaskGroup."""
old_task_group = cls._context_managed_task_group
if cls._previous_context_managed_task_groups:
cls._context_managed_task_group = cls._previous_context_managed_task_groups.pop()
# NOTE(review): an ``else:`` appears to be missing before the next line. As shown,
# the restored outer group would immediately be overwritten with None.
cls._context_managed_task_group = None
return old_task_group
def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None:
"""Get the current TaskGroup."""
# Local import, presumably to avoid a circular import with airflow.models.dag.
from airflow.models.dag import DagContext
if not cls._context_managed_task_group:
dag = dag or DagContext.get_current_dag()
if dag:
# If there's currently a DAG but no TaskGroup, return the root TaskGroup of the dag.
return dag.task_group
return cls._context_managed_task_group
# Convert an operator, or a TaskGroup recursively, into the nested dict used by the
# Graph view. Every node has an "id" and a "value" dict of label/style attributes.
# TaskGroups also carry tooltip/"isMapped" flags and a "children" list.
def task_group_to_dict(task_item_or_group):
Create a nested dict representation of this TaskGroup and its children used to construct
the Graph.
from airflow.models.abstractoperator import AbstractOperator
# NOTE(review): this function appears to have lost lines in extraction: the docstring
# quotes, presumably ``children.append(...)`` wrappers around the two join-node dicts
# below, and closing brackets. The code lines are kept verbatim.
# Operators become leaf nodes.
if isinstance(task_item_or_group, AbstractOperator):
return {
"id": task_item_or_group.task_id,
"value": {
"label": task_item_or_group.label,
"labelStyle": f"fill:{task_item_or_group.ui_fgcolor};",
"style": f"fill:{task_item_or_group.ui_color};",
"rx": 5,
"ry": 5,
task_group = task_item_or_group
is_mapped = isinstance(task_group, MappedTaskGroup)
# Children are sorted by label, giving a deterministic order.
children = [
task_group_to_dict(child) for child in sorted(task_group.children.values(), key=lambda t: t.label)
# Emit a circular upstream join node only when this group has upstream groups or tasks.
if task_group.upstream_group_ids or task_group.upstream_task_ids:
"id": task_group.upstream_join_id,
"value": {
"label": "",
"labelStyle": f"fill:{task_group.ui_fgcolor};",
"style": f"fill:{task_group.ui_color};",
"shape": "circle",
if task_group.downstream_group_ids or task_group.downstream_task_ids:
# This is the join node used to reduce the number of edges between two TaskGroup.
"id": task_group.downstream_join_id,
"value": {
"label": "",
"labelStyle": f"fill:{task_group.ui_fgcolor};",
"style": f"fill:{task_group.ui_color};",
"shape": "circle",
return {
"id": task_group.group_id,
"value": {
"label": task_group.label,
"labelStyle": f"fill:{task_group.ui_fgcolor};",
# NOTE(review): unlike every other "style" value in this function, this one has no
# trailing ";". That may be unintentional, but the value is emitted to the UI, so
# confirm before changing it.
"style": f"fill:{task_group.ui_color}",
"rx": 5,
"ry": 5,
"clusterLabelPos": "top",
"tooltip": task_group.tooltip,
"isMapped": is_mapped,
"children": children,