blob: 521c53d94ac8661b99cdd3245566d7da20a096bc [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped
together when the DAG is displayed graphically.
"""
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Union
from airflow.exceptions import AirflowException, DuplicateTaskIdFound
from airflow.models.taskmixin import TaskMixin
if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
class TaskGroup(TaskMixin):
    """
    A collection of tasks. When set_downstream() or set_upstream() are called on the
    TaskGroup, it is applied across all tasks within the group if necessary.

    :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
        with group_id of TaskGroup or task_id of tasks in the DAG. Root TaskGroup has group_id
        set to None.
    :type group_id: str
    :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
        this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
        Default is True.
    :type prefix_group_id: bool
    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
        for the root TaskGroup.
    :type parent_group: TaskGroup
    :param dag: The DAG that this TaskGroup belongs to.
    :type dag: airflow.models.DAG
    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
    :type tooltip: str
    :param ui_color: The fill color of the TaskGroup node when displayed in the UI
    :type ui_color: str
    :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
    :type ui_fgcolor: str
    :raises AirflowException: if a root TaskGroup is given a parent_group, if a non-root
        TaskGroup is created with neither a parent_group nor a DAG, or if no parent
        TaskGroup can be resolved for a non-root TaskGroup.
    :raises ValueError: if a non-root group_id is not a non-empty str.
    :raises DuplicateTaskIdFound: if the (prefixed) group_id is already in use in this DAG.
    """

    def __init__(
        self,
        group_id: Optional[str],
        prefix_group_id: bool = True,
        parent_group: Optional["TaskGroup"] = None,
        dag: Optional["DAG"] = None,
        tooltip: str = "",
        ui_color: str = "CornflowerBlue",
        ui_fgcolor: str = "#000",
    ):
        # Local import; presumably to avoid a circular import with airflow.models.dag.
        from airflow.models.dag import DagContext

        self.prefix_group_id = prefix_group_id

        if group_id is None:
            # This creates a root TaskGroup.
            if parent_group:
                raise AirflowException("Root TaskGroup cannot have parent_group")
            # used_group_ids is shared across all TaskGroups in the same DAG to keep track
            # of used group_id to avoid duplication.
            self.used_group_ids: Set[Optional[str]] = set()
            self._parent_group = None
        else:
            if not isinstance(group_id, str):
                raise ValueError("group_id must be str")
            if not group_id:
                raise ValueError("group_id must not be empty")

            dag = dag or DagContext.get_current_dag()

            if not parent_group and not dag:
                raise AirflowException("TaskGroup can only be used inside a dag")

            # Fall back to the TaskGroup currently active as a context manager (or the
            # DAG's root TaskGroup) when no explicit parent is given.
            self._parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
            if not self._parent_group:
                raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
            # Share the parent's set so every TaskGroup in the tree checks the same ids.
            self.used_group_ids = self._parent_group.used_group_ids

        self._group_id = group_id
        # Check the full (possibly parent-prefixed) group_id, not the bare label.
        if self.group_id in self.used_group_ids:
            raise DuplicateTaskIdFound(f"group_id '{self.group_id}' has already been added to the DAG")
        self.used_group_ids.add(self.group_id)
        # Also reserve the ids of this group's Graph View join nodes so that no later
        # TaskGroup can be created with a colliding group_id.
        self.used_group_ids.add(self.downstream_join_id)
        self.used_group_ids.add(self.upstream_join_id)
        self.children: Dict[str, Union["BaseOperator", "TaskGroup"]] = {}
        if self._parent_group:
            self._parent_group.add(self)

        self.tooltip = tooltip
        self.ui_color = ui_color
        self.ui_fgcolor = ui_fgcolor

        # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
        # so that we can optimize the number of edges when entire TaskGroups depend on each other.
        self.upstream_group_ids: Set[Optional[str]] = set()
        self.downstream_group_ids: Set[Optional[str]] = set()
        self.upstream_task_ids: Set[Optional[str]] = set()
        self.downstream_task_ids: Set[Optional[str]] = set()

    @classmethod
    def create_root(cls, dag: "DAG") -> "TaskGroup":
        """Create a root TaskGroup with no group_id or parent."""
        return cls(group_id=None, dag=dag)

    @property
    def is_root(self) -> bool:
        """Returns True if this TaskGroup is the root TaskGroup (its group_id is falsy). Otherwise False"""
        return not self.group_id

    def __iter__(self):
        """Yield every task in this TaskGroup, descending into nested TaskGroups depth-first."""
        for child in self.children.values():
            if isinstance(child, TaskGroup):
                yield from child
            else:
                yield child

    def add(self, task: Union["BaseOperator", "TaskGroup"]) -> None:
        """
        Add a task or a child TaskGroup to this TaskGroup.

        The child is keyed by its group_id (TaskGroup) or task_id (task).

        :raises DuplicateTaskIdFound: if a child with the same key was already added.
        :raises AirflowException: if ``task`` is a TaskGroup that already has children.
        """
        key = task.group_id if isinstance(task, TaskGroup) else task.task_id

        if key in self.children:
            raise DuplicateTaskIdFound(f"Task id '{key}' has already been added to the DAG")

        if isinstance(task, TaskGroup):
            if task.children:
                raise AirflowException("Cannot add a non-empty TaskGroup")

        self.children[key] = task  # type: ignore

    @property
    def group_id(self) -> Optional[str]:
        """group_id of this TaskGroup, prefixed with the parent's group_id when the parent requests it."""
        if self._parent_group and self._parent_group.prefix_group_id and self._parent_group.group_id:
            return self._parent_group.child_id(self._group_id)

        return self._group_id

    @property
    def label(self) -> Optional[str]:
        """group_id excluding parent's group_id used as the node label in UI."""
        return self._group_id

    def update_relative(self, other: "TaskMixin", upstream=True) -> None:
        """
        Overrides TaskMixin.update_relative.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
        accordingly so that we can reduce the number of edges when displaying Graph View.

        :param other: the TaskGroup or task-like object being related to this TaskGroup.
        :param upstream: True if ``other`` is upstream of this TaskGroup, False if downstream.
        :raises AirflowException: if ``other`` is not a TaskGroup and one of its roots is
            not a BaseOperator.
        """
        from airflow.models.baseoperator import BaseOperator

        if isinstance(other, TaskGroup):
            # Handles setting relationship between a TaskGroup and another TaskGroup
            if upstream:
                parent, child = (self, other)
            else:
                parent, child = (other, self)

            parent.upstream_group_ids.add(child.group_id)
            child.downstream_group_ids.add(parent.group_id)
        else:
            # Handles setting relationship between a TaskGroup and a task
            for task in other.roots:
                if not isinstance(task, BaseOperator):
                    raise AirflowException(
                        "Relationships can only be set between TaskGroup "
                        f"or operators; received {task.__class__.__name__}"
                    )

                if upstream:
                    self.upstream_task_ids.add(task.task_id)
                else:
                    self.downstream_task_ids.add(task.task_id)

    def _set_relative(
        self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]], upstream: bool = False
    ) -> None:
        """
        Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.
        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
        """
        # Only the group's entry points (roots) receive upstream edges and only its exit
        # points (leaves) receive downstream edges; inner tasks are already ordered.
        if upstream:
            for task in self.get_roots():
                task.set_upstream(task_or_task_list)
        else:
            for task in self.get_leaves():
                task.set_downstream(task_or_task_list)

        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        for task_like in task_or_task_list:
            self.update_relative(task_like, upstream)

    def set_downstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) -> None:
        """Set a TaskGroup/task/list of task downstream of this TaskGroup."""
        self._set_relative(task_or_task_list, upstream=False)

    def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) -> None:
        """Set a TaskGroup/task/list of task upstream of this TaskGroup."""
        self._set_relative(task_or_task_list, upstream=True)

    def __enter__(self):
        """Make this TaskGroup the active one for tasks/TaskGroups created inside the ``with`` block."""
        TaskGroupContext.push_context_managed_task_group(self)
        return self

    def __exit__(self, _type, _value, _tb):
        """Restore the TaskGroup that was active before this one was entered."""
        TaskGroupContext.pop_context_managed_task_group()

    def has_task(self, task: "BaseOperator") -> bool:
        """Returns True if this TaskGroup or its children TaskGroups contains the given task."""
        if task.task_id in self.children:
            return True

        return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))

    def _get_child_ids(self) -> Set[str]:
        """
        Return the ``children`` keys of this TaskGroup and of every nested TaskGroup.

        ``has_task(task)`` is True exactly when ``task.task_id`` is in this set. Computing
        the set once lets get_roots/get_leaves test membership in O(1) instead of walking
        the whole group tree for every relative of every task.
        """
        child_ids: Set[str] = set()
        pending: List["TaskGroup"] = [self]
        while pending:
            group = pending.pop()
            child_ids.update(group.children)
            pending.extend(child for child in group.children.values() if isinstance(child, TaskGroup))
        return child_ids

    @property
    def roots(self) -> List["BaseOperator"]:
        """Required by TaskMixin"""
        return list(self.get_roots())

    @property
    def leaves(self) -> List["BaseOperator"]:
        """Required by TaskMixin"""
        return list(self.get_leaves())

    def get_roots(self) -> Generator["BaseOperator", None, None]:
        """
        Returns a generator of tasks that are root tasks, i.e. those with no upstream
        dependencies within the TaskGroup.
        """
        child_ids = self._get_child_ids()
        for task in self:
            if not any(parent.task_id in child_ids for parent in task.get_direct_relatives(upstream=True)):
                yield task

    def get_leaves(self) -> Generator["BaseOperator", None, None]:
        """
        Returns a generator of tasks that are leaf tasks, i.e. those with no downstream
        dependencies within the TaskGroup
        """
        child_ids = self._get_child_ids()
        for task in self:
            if not any(child.task_id in child_ids for child in task.get_direct_relatives(upstream=False)):
                yield task

    def child_id(self, label):
        """
        Prefix label with group_id if prefix_group_id is True. Otherwise return the label
        as-is.
        """
        if self.prefix_group_id and self.group_id:
            return f"{self.group_id}.{label}"

        return label

    @property
    def upstream_join_id(self) -> str:
        """
        If this TaskGroup has immediate upstream TaskGroups or tasks, a dummy node called
        upstream_join_id will be created in Graph View to join the edges coming into this
        TaskGroup from those upstreams, reducing the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.upstream_join_id"

    @property
    def downstream_join_id(self) -> str:
        """
        If this TaskGroup has immediate downstream TaskGroups or tasks, a dummy node called
        downstream_join_id will be created in Graph View to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.downstream_join_id"

    def get_task_group_dict(self) -> Dict[str, "TaskGroup"]:
        """Returns a flat dictionary of group_id: TaskGroup for this TaskGroup and all nested ones"""
        task_group_map = {}

        def build_map(task_group):
            # Tasks are skipped; only TaskGroups are recorded and descended into.
            if not isinstance(task_group, TaskGroup):
                return

            task_group_map[task_group.group_id] = task_group

            for child in task_group.children.values():
                build_map(child)

        build_map(self)
        return task_group_map

    def get_child_by_label(self, label: str) -> Union["BaseOperator", "TaskGroup"]:
        """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)"""
        return self.children[self.child_id(label)]
class TaskGroupContext:
    """
    Tracks which TaskGroup is currently active as a context manager.

    The state lives in class attributes, so it is global rather than per-instance. Nested
    ``with TaskGroup(...)`` blocks are supported by stacking the previously active groups.
    """

    _context_managed_task_group: Optional[TaskGroup] = None
    _previous_context_managed_task_groups: List[TaskGroup] = []

    @classmethod
    def push_context_managed_task_group(cls, task_group: TaskGroup):
        """Activate ``task_group``, saving the currently active TaskGroup (if any) on the stack."""
        active = cls._context_managed_task_group
        if active:
            cls._previous_context_managed_task_groups.append(active)
        cls._context_managed_task_group = task_group

    @classmethod
    def pop_context_managed_task_group(cls) -> Optional[TaskGroup]:
        """
        Deactivate the current TaskGroup and re-activate the most recently saved one.

        :return: the TaskGroup that was active before this call (may be None).
        """
        deactivated = cls._context_managed_task_group
        saved = cls._previous_context_managed_task_groups
        cls._context_managed_task_group = saved.pop() if saved else None
        return deactivated

    @classmethod
    def get_current_task_group(cls, dag: Optional["DAG"]) -> Optional[TaskGroup]:
        """
        Return the active TaskGroup.

        With no active TaskGroup, fall back to the root TaskGroup of ``dag`` (or of the DAG
        returned by ``DagContext.get_current_dag()`` when ``dag`` is not given).
        """
        from airflow.models.dag import DagContext

        if cls._context_managed_task_group:
            return cls._context_managed_task_group

        dag = dag or DagContext.get_current_dag()
        if dag:
            # A DAG is available but no TaskGroup is active: use the DAG's root TaskGroup.
            return dag.task_group

        return cls._context_managed_task_group