#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from datetime import timedelta
import pendulum
import pytest
from airflow.decorators import (
dag,
setup,
task as task_decorator,
task_group as task_group_decorator,
teardown,
)
from airflow.exceptions import TaskAlreadyInTaskGroup
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.xcom_arg import XComArg
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.utils.dag_edges import dag_edges
from airflow.utils.task_group import TaskGroup, task_group_to_dict
from tests.models import DEFAULT_DATE
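# Build either a classic BashOperator or a TaskFlow-decorated task with the given id.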
def make_task(name, type_="classic"):
if type_ == "classic":
return BashOperator(task_id=name, bash_command="echo 1")
else:
@task_decorator
def my_task():
pass
return my_task.override(task_id=name)()
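# Expected task_group_to_dict() output, shared by the context-manager and
# constructor-style TaskGroup tests below.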
EXPECTED_JSON = {
"id": None,
"value": {
"label": None,
"labelStyle": "fill:#000;",
"style": "fill:CornflowerBlue",
"rx": 5,
"ry": 5,
"clusterLabelPos": "top",
"isMapped": False,
"tooltip": "",
},
"children": [
{
"id": "group234",
"value": {
"label": "group234",
"labelStyle": "fill:#000;",
"style": "fill:CornflowerBlue",
"rx": 5,
"ry": 5,
"clusterLabelPos": "top",
"tooltip": "",
"isMapped": False,
},
"children": [
{
"id": "group234.group34",
"value": {
"label": "group34",
"labelStyle": "fill:#000;",
"style": "fill:CornflowerBlue",
"rx": 5,
"ry": 5,
"clusterLabelPos": "top",
"tooltip": "",
"isMapped": False,
},
"children": [
{
"id": "group234.group34.task3",
"value": {
"label": "task3",
"labelStyle": "fill:#000;",
"style": "fill:#e8f7e4;",
"rx": 5,
"ry": 5,
},
},
{
"id": "group234.group34.task4",
"value": {
"label": "task4",
"labelStyle": "fill:#000;",
"style": "fill:#e8f7e4;",
"rx": 5,
"ry": 5,
},
},
{
"id": "group234.group34.downstream_join_id",
"value": {
"label": "",
"labelStyle": "fill:#000;",
"style": "fill:CornflowerBlue;",
"shape": "circle",
},
},
],
},
{
"id": "group234.task2",
"value": {
"label": "task2",
"labelStyle": "fill:#000;",
"style": "fill:#e8f7e4;",
"rx": 5,
"ry": 5,
},
},
{
"id": "group234.upstream_join_id",
"value": {
"label": "",
"labelStyle": "fill:#000;",
"style": "fill:CornflowerBlue;",
"shape": "circle",
},
},
],
},
{
"id": "task1",
"value": {
"label": "task1",
"labelStyle": "fill:#000;",
"style": "fill:#e8f7e4;",
"rx": 5,
"ry": 5,
},
},
{
"id": "task5",
"value": {
"label": "task5",
"labelStyle": "fill:#000;",
"style": "fill:#e8f7e4;",
"rx": 5,
"ry": 5,
},
},
],
}
def test_build_task_group_context_manager():
execution_date = pendulum.parse("20200101")
with DAG("test_build_task_group_context_manager", start_date=execution_date) as dag:
task1 = EmptyOperator(task_id="task1")
with TaskGroup("group234") as group234:
_ = EmptyOperator(task_id="task2")
with TaskGroup("group34") as group34:
_ = EmptyOperator(task_id="task3")
_ = EmptyOperator(task_id="task4")
task5 = EmptyOperator(task_id="task5")
task1 >> group234
group34 >> task5
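# Arrows into a TaskGroup attach to its root tasks and arrows out of it to its
# leaf tasks, which is what the relative-id assertions below verify.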
assert task1.get_direct_relative_ids(upstream=False) == {
"group234.group34.task4",
"group234.group34.task3",
"group234.task2",
}
assert task5.get_direct_relative_ids(upstream=True) == {
"group234.group34.task4",
"group234.group34.task3",
}
assert dag.task_group.group_id is None
assert dag.task_group.is_root
assert set(dag.task_group.children.keys()) == {"task1", "group234", "task5"}
assert group34.group_id == "group234.group34"
assert task_group_to_dict(dag.task_group) == EXPECTED_JSON
def test_build_task_group():
"""
This is an alternative syntax for creating a TaskGroup. It should result in the
same TaskGroup as the context-manager syntax.
"""
execution_date = pendulum.parse("20200101")
dag = DAG("test_build_task_group", start_date=execution_date)
task1 = EmptyOperator(task_id="task1", dag=dag)
group234 = TaskGroup("group234", dag=dag)
_ = EmptyOperator(task_id="task2", dag=dag, task_group=group234)
group34 = TaskGroup("group34", dag=dag, parent_group=group234)
_ = EmptyOperator(task_id="task3", dag=dag, task_group=group34)
_ = EmptyOperator(task_id="task4", dag=dag, task_group=group34)
task5 = EmptyOperator(task_id="task5", dag=dag)
task1 >> group234
group34 >> task5
assert task_group_to_dict(dag.task_group) == EXPECTED_JSON
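# Recursively reduce task_group_to_dict() output to ids (and optionally labels)
# so the structural assertions below stay compact.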
def extract_node_id(node, include_label=False):
ret = {"id": node["id"]}
if include_label:
ret["label"] = node["value"]["label"]
if "children" in node:
children = []
for child in node["children"]:
children.append(extract_node_id(child, include_label=include_label))
ret["children"] = children
return ret
def test_build_task_group_with_prefix():
"""
Tests that prefix_group_id turns on/off prefixing of task_id with group_id.
"""
execution_date = pendulum.parse("20200101")
with DAG("test_build_task_group_with_prefix", start_date=execution_date) as dag:
task1 = EmptyOperator(task_id="task1")
with TaskGroup("group234", prefix_group_id=False) as group234:
task2 = EmptyOperator(task_id="task2")
with TaskGroup("group34") as group34:
task3 = EmptyOperator(task_id="task3")
with TaskGroup("group4", prefix_group_id=False) as group4:
task4 = EmptyOperator(task_id="task4")
task5 = EmptyOperator(task_id="task5")
task1 >> group234
group34 >> task5
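# prefix_group_id=False stops a group from prefixing its children's ids; the
# group's own group_id is still prefixed by its parent chain.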
assert task2.task_id == "task2"
assert group34.group_id == "group34"
assert task3.task_id == "group34.task3"
assert group4.group_id == "group34.group4"
assert task4.task_id == "task4"
assert task5.task_id == "task5"
assert group234.get_child_by_label("task2") == task2
assert group234.get_child_by_label("group34") == group34
assert group4.get_child_by_label("task4") == task4
assert extract_node_id(task_group_to_dict(dag.task_group), include_label=True) == {
"id": None,
"label": None,
"children": [
{
"id": "group234",
"label": "group234",
"children": [
{
"id": "group34",
"label": "group34",
"children": [
{
"id": "group34.group4",
"label": "group4",
"children": [{"id": "task4", "label": "task4"}],
},
{"id": "group34.task3", "label": "task3"},
{"id": "group34.downstream_join_id", "label": ""},
],
},
{"id": "task2", "label": "task2"},
{"id": "group234.upstream_join_id", "label": ""},
],
},
{"id": "task1", "label": "task1"},
{"id": "task5", "label": "task5"},
],
}
def test_build_task_group_with_task_decorator():
"""
Test that TaskGroup can be used with the @task decorator.
"""
from airflow.decorators import task
@task
def task_1():
print("task_1")
@task
def task_2():
return "task_2"
@task
def task_3():
return "task_3"
@task
def task_4(task_2_output, task_3_output):
print(task_2_output, task_3_output)
@task
def task_5():
print("task_5")
execution_date = pendulum.parse("20200101")
with DAG("test_build_task_group_with_task_decorator", start_date=execution_date) as dag:
tsk_1 = task_1()
with TaskGroup("group234") as group234:
tsk_2 = task_2()
tsk_3 = task_3()
tsk_4 = task_4(tsk_2, tsk_3)
tsk_5 = task_5()
tsk_1 >> group234 >> tsk_5
assert tsk_1.operator in tsk_2.operator.upstream_list
assert tsk_1.operator in tsk_3.operator.upstream_list
assert tsk_5.operator in tsk_4.operator.downstream_list
assert extract_node_id(task_group_to_dict(dag.task_group)) == {
"id": None,
"children": [
{
"id": "group234",
"children": [
{"id": "group234.task_2"},
{"id": "group234.task_3"},
{"id": "group234.task_4"},
{"id": "group234.upstream_join_id"},
{"id": "group234.downstream_join_id"},
],
},
{"id": "task_1"},
{"id": "task_5"},
],
}
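# dag_edges() routes group-level dependencies through the synthetic
# upstream_join_id/downstream_join_id proxy nodes.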
edges = dag_edges(dag)
assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
("group234.downstream_join_id", "task_5"),
("group234.task_2", "group234.task_4"),
("group234.task_3", "group234.task_4"),
("group234.task_4", "group234.downstream_join_id"),
("group234.upstream_join_id", "group234.task_2"),
("group234.upstream_join_id", "group234.task_3"),
("task_1", "group234.upstream_join_id"),
]
def test_sub_dag_task_group():
"""
Tests that dag.partial_subset() updates the task_group correctly.
"""
execution_date = pendulum.parse("20200101")
with DAG("test_sub_dag_task_group", start_date=execution_date) as dag:
task1 = EmptyOperator(task_id="task1")
with TaskGroup("group234") as group234:
_ = EmptyOperator(task_id="task2")
with TaskGroup("group34") as group34:
_ = EmptyOperator(task_id="task3")
_ = EmptyOperator(task_id="task4")
with TaskGroup("group6") as group6:
_ = EmptyOperator(task_id="task6")
task7 = EmptyOperator(task_id="task7")
task5 = EmptyOperator(task_id="task5")
task1 >> group234
group34 >> task5
group234 >> group6
group234 >> task7
subdag = dag.partial_subset(task_ids_or_regex="task5", include_upstream=True, include_downstream=False)
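# Only task5's upstream chain survives the subset: task2, group6 and task7 are
# dropped along with their edges.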
assert extract_node_id(task_group_to_dict(subdag.task_group)) == {
"id": None,
"children": [
{
"id": "group234",
"children": [
{
"id": "group234.group34",
"children": [
{"id": "group234.group34.task3"},
{"id": "group234.group34.task4"},
{"id": "group234.group34.downstream_join_id"},
],
},
{"id": "group234.upstream_join_id"},
],
},
{"id": "task1"},
{"id": "task5"},
],
}
edges = dag_edges(subdag)
assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
("group234.group34.downstream_join_id", "task5"),
("group234.group34.task3", "group234.group34.downstream_join_id"),
("group234.group34.task4", "group234.group34.downstream_join_id"),
("group234.upstream_join_id", "group234.group34.task3"),
("group234.upstream_join_id", "group234.group34.task4"),
("task1", "group234.upstream_join_id"),
]
subdag_task_groups = subdag.task_group.get_task_group_dict()
assert subdag_task_groups.keys() == {None, "group234", "group234.group34"}
included_group_ids = {"group234", "group234.group34"}
included_task_ids = {"group234.group34.task3", "group234.group34.task4", "task1", "task5"}
for task_group in subdag_task_groups.values():
assert task_group.upstream_group_ids.issubset(included_group_ids)
assert task_group.downstream_group_ids.issubset(included_group_ids)
assert task_group.upstream_task_ids.issubset(included_task_ids)
assert task_group.downstream_task_ids.issubset(included_task_ids)
for task in subdag.task_group:
assert task.upstream_task_ids.issubset(included_task_ids)
assert task.downstream_task_ids.issubset(included_task_ids)
def test_dag_edges():
execution_date = pendulum.parse("20200101")
with DAG("test_dag_edges", start_date=execution_date) as dag:
task1 = EmptyOperator(task_id="task1")
with TaskGroup("group_a") as group_a:
with TaskGroup("group_b") as group_b:
task2 = EmptyOperator(task_id="task2")
task3 = EmptyOperator(task_id="task3")
task4 = EmptyOperator(task_id="task4")
task2 >> [task3, task4]
task5 = EmptyOperator(task_id="task5")
task5 << group_b
task1 >> group_a
with TaskGroup("group_c") as group_c:
task6 = EmptyOperator(task_id="task6")
task7 = EmptyOperator(task_id="task7")
task8 = EmptyOperator(task_id="task8")
[task6, task7] >> task8
group_a >> group_c
task5 >> task8
task9 = EmptyOperator(task_id="task9")
task10 = EmptyOperator(task_id="task10")
group_c >> [task9, task10]
with TaskGroup("group_d") as group_d:
task11 = EmptyOperator(task_id="task11")
task12 = EmptyOperator(task_id="task12")
task11 >> task12
group_d << group_c
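# A mix of group-to-group, group-to-task and task-to-task arrows; each
# group-level dependency is proxied through the groups' join nodes.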
nodes = task_group_to_dict(dag.task_group)
edges = dag_edges(dag)
assert extract_node_id(nodes) == {
"id": None,
"children": [
{
"id": "group_a",
"children": [
{
"id": "group_a.group_b",
"children": [
{"id": "group_a.group_b.task2"},
{"id": "group_a.group_b.task3"},
{"id": "group_a.group_b.task4"},
{"id": "group_a.group_b.downstream_join_id"},
],
},
{"id": "group_a.task5"},
{"id": "group_a.upstream_join_id"},
{"id": "group_a.downstream_join_id"},
],
},
{
"id": "group_c",
"children": [
{"id": "group_c.task6"},
{"id": "group_c.task7"},
{"id": "group_c.task8"},
{"id": "group_c.upstream_join_id"},
{"id": "group_c.downstream_join_id"},
],
},
{
"id": "group_d",
"children": [
{"id": "group_d.task11"},
{"id": "group_d.task12"},
{"id": "group_d.upstream_join_id"},
],
},
{"id": "task1"},
{"id": "task10"},
{"id": "task9"},
],
}
assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
("group_a.downstream_join_id", "group_c.upstream_join_id"),
("group_a.group_b.downstream_join_id", "group_a.task5"),
("group_a.group_b.task2", "group_a.group_b.task3"),
("group_a.group_b.task2", "group_a.group_b.task4"),
("group_a.group_b.task3", "group_a.group_b.downstream_join_id"),
("group_a.group_b.task4", "group_a.group_b.downstream_join_id"),
("group_a.task5", "group_a.downstream_join_id"),
("group_a.task5", "group_c.task8"),
("group_a.upstream_join_id", "group_a.group_b.task2"),
("group_c.downstream_join_id", "group_d.upstream_join_id"),
("group_c.downstream_join_id", "task10"),
("group_c.downstream_join_id", "task9"),
("group_c.task6", "group_c.task8"),
("group_c.task7", "group_c.task8"),
("group_c.task8", "group_c.downstream_join_id"),
("group_c.upstream_join_id", "group_c.task6"),
("group_c.upstream_join_id", "group_c.task7"),
("group_d.task11", "group_d.task12"),
("group_d.upstream_join_id", "group_d.task11"),
("task1", "group_a.upstream_join_id"),
]
def test_dag_edges_setup_teardown():
execution_date = pendulum.parse("20200101")
with DAG("test_dag_edges_setup_teardown", start_date=execution_date) as dag:
setup1 = EmptyOperator(task_id="setup1").as_setup()
teardown1 = EmptyOperator(task_id="teardown1").as_teardown()
with setup1 >> teardown1:
EmptyOperator(task_id="task1")
with TaskGroup("group_a"):
setup2 = EmptyOperator(task_id="setup2").as_setup()
teardown2 = EmptyOperator(task_id="teardown2").as_teardown()
with setup2 >> teardown2:
EmptyOperator(task_id="task2")
edges = dag_edges(dag)
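# setup >> teardown edges are flagged with is_setup_teardown=True; ordinary
# work-task edges carry no flag.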
assert sorted((e["source_id"], e["target_id"], e.get("is_setup_teardown")) for e in edges) == [
("group_a.setup2", "group_a.task2", None),
("group_a.setup2", "group_a.teardown2", True),
("group_a.task2", "group_a.teardown2", None),
("setup1", "task1", None),
("setup1", "teardown1", True),
("task1", "teardown1", None),
]
def test_dag_edges_setup_teardown_nested():
from airflow.decorators import task, task_group
from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator
execution_date = pendulum.parse("20200101")
with DAG(dag_id="s_t_dag", start_date=execution_date) as dag:
@task
def test_task():
print("Hello world!")
@task_group
def inner():
inner_start = EmptyOperator(task_id="start")
inner_end = EmptyOperator(task_id="end")
test_task_r = test_task.override(task_id="work")()
inner_start >> test_task_r >> inner_end.as_teardown(setups=inner_start)
@task_group
def outer():
outer_work = EmptyOperator(task_id="work")
inner_group = inner()
inner_group >> outer_work
dag_start = EmptyOperator(task_id="dag_start")
dag_end = EmptyOperator(task_id="dag_end")
dag_start >> outer() >> dag_end
edges = dag_edges(dag)
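# The inner teardown ("end") is not a leaf of its group: only the work task
# feeds inner's downstream_join_id, while setup >> teardown is flagged separately.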
actual = sorted((e["source_id"], e["target_id"], e.get("is_setup_teardown")) for e in edges)
assert actual == [
("dag_start", "outer.upstream_join_id", None),
("outer.downstream_join_id", "dag_end", None),
("outer.inner.downstream_join_id", "outer.work", None),
("outer.inner.start", "outer.inner.end", True),
("outer.inner.start", "outer.inner.work", None),
("outer.inner.work", "outer.inner.downstream_join_id", None),
("outer.inner.work", "outer.inner.end", None),
("outer.upstream_join_id", "outer.inner.start", None),
("outer.work", "outer.downstream_join_id", None),
]
def test_duplicate_group_id():
from airflow.exceptions import DuplicateTaskIdFound
execution_date = pendulum.parse("20200101")
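# group_ids share a namespace with task_ids, and the reserved
# upstream/downstream join ids cannot be used as task_ids inside a group.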
with DAG("test_duplicate_group_id", start_date=execution_date):
_ = EmptyOperator(task_id="task1")
with pytest.raises(DuplicateTaskIdFound, match=r".* 'task1' .*"), TaskGroup("task1"):
pass
with DAG("test_duplicate_group_id", start_date=execution_date):
_ = EmptyOperator(task_id="task1")
with TaskGroup("group1", prefix_group_id=False):
with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1' .*"), TaskGroup("group1"):
pass
with DAG("test_duplicate_group_id", start_date=execution_date):
with TaskGroup("group1", prefix_group_id=False):
with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1' .*"):
_ = EmptyOperator(task_id="group1")
with DAG("test_duplicate_group_id", start_date=execution_date):
_ = EmptyOperator(task_id="task1")
with TaskGroup("group1"):
with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1.downstream_join_id' .*"):
_ = EmptyOperator(task_id="downstream_join_id")
with DAG("test_duplicate_group_id", start_date=execution_date):
_ = EmptyOperator(task_id="task1")
with TaskGroup("group1"):
with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1.upstream_join_id' .*"):
_ = EmptyOperator(task_id="upstream_join_id")
def test_task_without_dag():
"""
Test that when a task without a DAG is set as a relative of a task that has a
DAG, it is added to the root TaskGroup of the other task's DAG.
"""
dag = DAG(dag_id="test_task_without_dag", start_date=pendulum.parse("20200101"))
op1 = EmptyOperator(task_id="op1", dag=dag)
op2 = EmptyOperator(task_id="op2")
op3 = EmptyOperator(task_id="op3")
op1 >> op2
op3 >> op2
assert op1.dag == op2.dag == op3.dag
assert dag.task_group.children.keys() == {"op1", "op2", "op3"}
assert dag.task_group.children.keys() == dag.task_dict.keys()
# taskgroup decorator tests
def test_build_task_group_deco_context_manager():
"""
Tests the following:
1. Nested TaskGroups created with the task_group decorator match the TaskGroups
created with the TaskGroup context manager.
2. TaskGroups containing tasks created with the task decorator.
3. Node ids of DAGs created with the task_group decorator.
"""
from airflow.decorators import task
# Creating Tasks
@task
def task_start():
"""Dummy Task which is First Task of Dag"""
return "[Task_start]"
@task
def task_end():
"""Dummy Task which is Last Task of Dag"""
print("[ Task_End ]")
@task
def task_1(value):
"""Dummy Task1"""
return f"[ Task1 {value} ]"
@task
def task_2(value):
"""Dummy Task2"""
print(f"[ Task2 {value} ]")
@task
def task_3(value):
"""Dummy Task3"""
return f"[ Task3 {value} ]"
@task
def task_4(value):
"""Dummy Task3"""
print(f"[ Task4 {value} ]")
# Creating TaskGroups
@task_group_decorator
def section_1(value):
"""TaskGroup for grouping related Tasks"""
@task_group_decorator()
def section_2(value2):
"""TaskGroup for grouping related Tasks"""
return task_4(task_3(value2))
op1 = task_2(task_1(value))
return section_2(op1)
execution_date = pendulum.parse("20201109")
with DAG(
dag_id="example_nested_task_group_decorator", start_date=execution_date, tags=["example"]
) as dag:
t_start = task_start()
sec_1 = section_1(t_start)
sec_1.set_downstream(task_end())
# Testing the TaskGroup created with the task_group decorator
assert set(dag.task_group.children.keys()) == {"task_start", "task_end", "section_1"}
assert set(dag.task_group.children["section_1"].children.keys()) == {
"section_1.task_1",
"section_1.task_2",
"section_1.section_2",
}
# Testing a TaskGroup containing tasks created with the task decorator
assert dag.task_dict["task_start"].downstream_task_ids == {"section_1.task_1"}
assert dag.task_dict["section_1.task_2"].downstream_task_ids == {"section_1.section_2.task_3"}
assert dag.task_dict["section_1.section_2.task_4"].downstream_task_ids == {"task_end"}
# Node IDs test
node_ids = {
"id": None,
"children": [
{
"id": "section_1",
"children": [
{
"id": "section_1.section_2",
"children": [
{"id": "section_1.section_2.task_3"},
{"id": "section_1.section_2.task_4"},
],
},
{"id": "section_1.task_1"},
{"id": "section_1.task_2"},
],
},
{"id": "task_end"},
{"id": "task_start"},
],
}
assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids
def test_build_task_group_depended_by_task():
"""A decorator-based task group should be able to be used as a relative to operators."""
from airflow.decorators import dag as dag_decorator, task
@dag_decorator(start_date=pendulum.now())
def build_task_group_depended_by_task():
@task
def task_start():
return "[Task_start]"
@task
def task_end():
return "[Task_end]"
@task
def task_thing(value):
return f"[Task_thing {value}]"
@task_group_decorator
def section_1():
task_thing(1)
task_thing(2)
task_start() >> section_1() >> task_end()
dag = build_task_group_depended_by_task()
task_thing_1 = dag.task_dict["section_1.task_thing"]
task_thing_2 = dag.task_dict["section_1.task_thing__1"]
# Tasks in the task group don't depend on each other; both are downstream of
# task_start and upstream of task_end.
assert task_thing_1.upstream_task_ids == task_thing_2.upstream_task_ids == {"task_start"}
assert task_thing_1.downstream_task_ids == task_thing_2.downstream_task_ids == {"task_end"}
def test_build_task_group_with_operators():
"""Tests DAG with Tasks created with *Operators and TaskGroup created with taskgroup decorator"""
from airflow.decorators import task
def task_start():
"""Dummy Task which is First Task of Dag"""
return "[Task_start]"
def task_end():
"""Dummy Task which is Last Task of Dag"""
print("[ Task_End ]")
# Creating Tasks
@task
def task_1(value):
"""Dummy Task1"""
return f"[ Task1 {value} ]"
@task
def task_2(value):
"""Dummy Task2"""
return f"[ Task2 {value} ]"
@task
def task_3(value):
"""Dummy Task3"""
print(f"[ Task3 {value} ]")
# Creating TaskGroups
@task_group_decorator(group_id="section_1")
def section_a(value):
"""TaskGroup for grouping related Tasks"""
return task_3(task_2(task_1(value)))
execution_date = pendulum.parse("20201109")
with DAG(dag_id="example_task_group_decorator_mix", start_date=execution_date, tags=["example"]) as dag:
t_start = PythonOperator(task_id="task_start", python_callable=task_start, dag=dag)
sec_1 = section_a(t_start.output)
t_end = PythonOperator(task_id="task_end", python_callable=task_end, dag=dag)
sec_1.set_downstream(t_end)
# Testing Tasks in DAG
assert set(dag.task_group.children.keys()) == {"section_1", "task_start", "task_end"}
assert set(dag.task_group.children["section_1"].children.keys()) == {
"section_1.task_2",
"section_1.task_3",
"section_1.task_1",
}
# Testing Tasks downstream
assert dag.task_dict["task_start"].downstream_task_ids == {"section_1.task_1"}
assert dag.task_dict["section_1.task_3"].downstream_task_ids == {"task_end"}
def test_task_group_context_mix():
"""Test cases to check nested TaskGroup context manager with taskgroup decorator"""
from airflow.decorators import task
def task_start():
"""Dummy Task which is First Task of Dag"""
return "[Task_start]"
def task_end():
"""Dummy Task which is Last Task of Dag"""
print("[ Task_End ]")
# Creating Tasks
@task
def task_1(value):
"""Dummy Task1"""
return f"[ Task1 {value} ]"
@task
def task_2(value):
"""Dummy Task2"""
return f"[ Task2 {value} ]"
@task
def task_3(value):
"""Dummy Task3"""
print(f"[ Task3 {value} ]")
# Creating TaskGroups
@task_group_decorator
def section_2(value):
"""TaskGroup for grouping related Tasks"""
return task_3(task_2(task_1(value)))
execution_date = pendulum.parse("20201109")
with DAG(dag_id="example_task_group_decorator_mix", start_date=execution_date, tags=["example"]) as dag:
t_start = PythonOperator(task_id="task_start", python_callable=task_start, dag=dag)
with TaskGroup("section_1", tooltip="section_1") as section_1:
sec_2 = section_2(t_start.output)
task_s1 = EmptyOperator(task_id="task_1")
task_s2 = BashOperator(task_id="task_2", bash_command="echo 1")
task_s3 = EmptyOperator(task_id="task_3")
sec_2.set_downstream(task_s1)
task_s1 >> [task_s2, task_s3]
t_end = PythonOperator(task_id="task_end", python_callable=task_end, dag=dag)
t_start >> section_1 >> t_end
node_ids = {
"id": None,
"children": [
{
"id": "section_1",
"children": [
{
"id": "section_1.section_2",
"children": [
{"id": "section_1.section_2.task_1"},
{"id": "section_1.section_2.task_2"},
{"id": "section_1.section_2.task_3"},
],
},
{"id": "section_1.task_1"},
{"id": "section_1.task_2"},
{"id": "section_1.task_3"},
{"id": "section_1.upstream_join_id"},
{"id": "section_1.downstream_join_id"},
],
},
{"id": "task_end"},
{"id": "task_start"},
],
}
assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids
def test_default_args():
"""Testing TaskGroup with default_args"""
execution_date = pendulum.parse("20201109")
with DAG(
dag_id="example_task_group_default_args",
start_date=execution_date,
default_args={
"owner": "dag",
},
):
with TaskGroup("group1", default_args={"owner": "group"}):
task_1 = EmptyOperator(task_id="task_1")
task_2 = EmptyOperator(task_id="task_2", owner="task")
task_3 = EmptyOperator(task_id="task_3", default_args={"owner": "task"})
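# Precedence: explicit operator arguments and per-task default_args beat the
# group's default_args, which in turn beat the DAG's.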
assert task_1.owner == "group"
assert task_2.owner == "task"
assert task_3.owner == "task"
def test_duplicate_task_group_id():
"""Testing automatic suffix assignment for duplicate group_id"""
from airflow.decorators import task
@task(task_id="start_task")
def task_start():
"""Dummy Task which is First Task of Dag"""
print("[Task_start]")
@task(task_id="end_task")
def task_end():
"""Dummy Task which is Last Task of Dag"""
print("[Task_End]")
# Creating Tasks
@task(task_id="task")
def task_1():
"""Dummy Task1"""
print("[Task1]")
@task(task_id="task")
def task_2():
"""Dummy Task2"""
print("[Task2]")
@task(task_id="task1")
def task_3():
"""Dummy Task3"""
print("[Task3]")
@task_group_decorator("task_group1")
def task_group1():
task_start()
task_1()
task_2()
@task_group_decorator(group_id="task_group1")
def task_group2():
task_3()
@task_group_decorator(group_id="task_group1")
def task_group3():
task_end()
execution_date = pendulum.parse("20201109")
with DAG(dag_id="example_duplicate_task_group_id", start_date=execution_date, tags=["example"]) as dag:
task_group1()
task_group2()
task_group3()
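# The reused group_id "task_group1" is auto-suffixed (__1, __2), as is the
# duplicate task_id "task" inside the first group.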
node_ids = {
"id": None,
"children": [
{
"id": "task_group1",
"children": [
{"id": "task_group1.start_task"},
{"id": "task_group1.task"},
{"id": "task_group1.task__1"},
],
},
{"id": "task_group1__1", "children": [{"id": "task_group1__1.task1"}]},
{"id": "task_group1__2", "children": [{"id": "task_group1__2.end_task"}]},
],
}
assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids
def test_call_taskgroup_twice():
"""Test for using same taskgroup decorated function twice"""
from airflow.decorators import task
@task(task_id="start_task")
def task_start():
"""Dummy Task which is First Task of Dag"""
print("[Task_start]")
@task(task_id="end_task")
def task_end():
"""Dummy Task which is Last Task of Dag"""
print("[Task_End]")
# Creating Tasks
@task(task_id="task")
def task_1():
"""Dummy Task1"""
print("[Task1]")
@task_group_decorator
def task_group1(name: str):
print(f"Starting taskgroup {name}")
task_start()
task_1()
task_end()
execution_date = pendulum.parse("20201109")
with DAG(dag_id="example_multi_call_task_groups", start_date=execution_date, tags=["example"]) as dag:
task_group1("Call1")
task_group1("Call2")
node_ids = {
"id": None,
"children": [
{
"id": "task_group1",
"children": [
{"id": "task_group1.end_task"},
{"id": "task_group1.start_task"},
{"id": "task_group1.task"},
],
},
{
"id": "task_group1__1",
"children": [
{"id": "task_group1__1.end_task"},
{"id": "task_group1__1.start_task"},
{"id": "task_group1__1.task"},
],
},
],
}
assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids
def test_pass_taskgroup_output_to_task():
"""Test that the output of a task group can be passed to a task."""
from airflow.decorators import task
@task
def one():
return 1
@task_group_decorator
def addition_task_group(num):
@task
def add_one(i):
return i + 1
return add_one(num)
@task
def increment(num):
return num + 1
@dag(schedule=None, start_date=pendulum.DateTime(2022, 1, 1), default_args={"owner": "airflow"})
def wrap():
total_1 = one()
assert isinstance(total_1, XComArg)
total_2 = addition_task_group(total_1)
assert isinstance(total_2, XComArg)
total_3 = increment(total_2)
assert isinstance(total_3, XComArg)
wrap()
def test_decorator_unknown_args():
"""Test that unknown args passed to the decorator cause an error at parse time"""
with pytest.raises(TypeError):
@task_group_decorator(b=2)
def tg(): ...
def test_decorator_multiple_use_task():
from airflow.decorators import task
@dag("test-dag", start_date=DEFAULT_DATE)
def _test_dag():
@task
def t():
pass
@task_group_decorator
def tg():
for _ in range(3):
t()
t() >> tg() >> t()
test_dag = _test_dag()
assert test_dag.task_ids == [
"t", # Start end.
"tg.t",
"tg.t__1",
"tg.t__2",
"t__1", # End node.
]
def test_topological_sort1():
dag = DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"})
# B >> A, C >> A, C >> D
# Valid orders place B, C and D (with C before D) ahead of A; the assertions
# below accept the first three in any order, with A last.
with dag:
op1 = EmptyOperator(task_id="A")
op2 = EmptyOperator(task_id="B")
op3 = EmptyOperator(task_id="C")
op4 = EmptyOperator(task_id="D")
[op2, op3] >> op1
op3 >> op4
topological_list = dag.task_group.topological_sort()
tasks = [op2, op3, op4]
assert topological_list[0] in tasks
tasks.remove(topological_list[0])
assert topological_list[1] in tasks
tasks.remove(topological_list[1])
assert topological_list[2] in tasks
tasks.remove(topological_list[2])
assert topological_list[3] == op1
def test_topological_sort2():
dag = DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"})
# D >> A, D >> B; A >> C, B >> C; E >> C
# A valid order starts with D or E and ends with C; the assertions below
# check that shape position by position.
with dag:
op1 = EmptyOperator(task_id="A")
op2 = EmptyOperator(task_id="B")
op3 = EmptyOperator(task_id="C")
op4 = EmptyOperator(task_id="D")
op5 = EmptyOperator(task_id="E")
op3 << [op1, op2]
op4 >> [op1, op2]
op5 >> op3
topological_list = dag.task_group.topological_sort()
set1 = [op4, op5]
assert topological_list[0] in set1
set1.remove(topological_list[0])
set2 = [op1, op2]
set2.extend(set1)
assert topological_list[1] in set2
set2.remove(topological_list[1])
assert topological_list[2] in set2
set2.remove(topological_list[2])
assert topological_list[3] in set2
assert topological_list[4] == op3
def test_topological_nested_groups():
execution_date = pendulum.parse("20200101")
with DAG("test_topological_nested_groups", start_date=execution_date) as dag:
task1 = EmptyOperator(task_id="task1")
task5 = EmptyOperator(task_id="task5")
with TaskGroup("group_a") as group_a:
with TaskGroup("group_b"):
task2 = EmptyOperator(task_id="task2")
task3 = EmptyOperator(task_id="task3")
task4 = EmptyOperator(task_id="task4")
task2 >> [task3, task4]
task1 >> group_a
group_a >> task5
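# topological_sort() yields TaskGroup nodes in place of their contents, so
# expand them recursively before comparing against plain task lists.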
def nested_topo(group):
return [
nested_topo(node) if isinstance(node, TaskGroup) else node for node in group.topological_sort()
]
topological_list = nested_topo(dag.task_group)
assert topological_list == [
task1,
[
[
task2,
task3,
task4,
],
],
task5,
]
def test_hierarchical_alphabetical_sort():
execution_date = pendulum.parse("20200101")
with DAG("test_hierarchical_alphabetical_sort", start_date=execution_date) as dag:
task1 = EmptyOperator(task_id="task1")
task5 = EmptyOperator(task_id="task5")
with TaskGroup("group_c"):
task7 = EmptyOperator(task_id="task7")
with TaskGroup("group_b"):
task6 = EmptyOperator(task_id="task6")
with TaskGroup("group_a"):
with TaskGroup("group_d"):
task2 = EmptyOperator(task_id="task2")
task3 = EmptyOperator(task_id="task3")
task4 = EmptyOperator(task_id="task4")
task9 = EmptyOperator(task_id="task9")
task8 = EmptyOperator(task_id="task8")
def nested(group):
return [
nested(node) if isinstance(node, TaskGroup) else node
for node in group.hierarchical_alphabetical_sort()
]
sorted_list = nested(dag.task_group)
assert sorted_list == [
[ # group_a
[ # group_d
task2,
task3,
task4,
],
task8,
task9,
],
[task6], # group_b
[task7], # group_c
task1,
task5,
]
def test_topological_group_dep():
execution_date = pendulum.parse("20200101")
with DAG("test_topological_group_dep", start_date=execution_date) as dag:
task1 = EmptyOperator(task_id="task1")
task6 = EmptyOperator(task_id="task6")
with TaskGroup("group_a") as group_a:
task2 = EmptyOperator(task_id="task2")
task3 = EmptyOperator(task_id="task3")
with TaskGroup("group_b") as group_b:
task4 = EmptyOperator(task_id="task4")
task5 = EmptyOperator(task_id="task5")
task1 >> group_a >> group_b >> task6
def nested_topo(group):
return [
nested_topo(node) if isinstance(node, TaskGroup) else node for node in group.topological_sort()
]
topological_list = nested_topo(dag.task_group)
assert topological_list == [
task1,
[
task2,
task3,
],
[
task4,
task5,
],
task6,
]
def test_add_to_sub_group():
with DAG("test_dag", start_date=pendulum.parse("20200101")):
tg = TaskGroup("section")
task = EmptyOperator(task_id="task")
with pytest.raises(TaskAlreadyInTaskGroup) as ctx:
tg.add(task)
assert str(ctx.value) == "cannot add 'task' to 'section' (already in the DAG's root group)"
def test_add_to_another_group():
with DAG("test_dag", start_date=pendulum.parse("20200101")):
tg = TaskGroup("section_1")
with TaskGroup("section_2"):
task = EmptyOperator(task_id="task")
with pytest.raises(TaskAlreadyInTaskGroup) as ctx:
tg.add(task)
assert str(ctx.value) == "cannot add 'section_2.task' to 'section_1' (already in group 'section_2')"
def test_task_group_edge_modifier_chain():
from airflow.models.baseoperator import chain
from airflow.utils.edgemodifier import Label
with DAG(dag_id="test", start_date=pendulum.DateTime(2022, 5, 20)) as dag:
start = EmptyOperator(task_id="sleep_3_seconds")
with TaskGroup(group_id="group1") as tg:
t1 = EmptyOperator(task_id="dummy1")
t2 = EmptyOperator(task_id="dummy2")
t3 = EmptyOperator(task_id="echo_done")
# The case we are testing for is when a Label is inside a list -- meaning that we do tg.set_upstream
# instead of label.set_downstream
chain(start, [Label("branch three")], tg, t3)
assert start.downstream_task_ids == {t1.node_id, t2.node_id}
assert t3.upstream_task_ids == {t1.node_id, t2.node_id}
assert tg.upstream_task_ids == set()
assert tg.downstream_task_ids == {t3.node_id}
# Check that we can perform a topological_sort
dag.topological_sort()
def test_mapped_task_group_id_prefix_task_id():
from tests.test_utils.mock_operators import MockOperator
with DAG(dag_id="d", start_date=DEFAULT_DATE) as dag:
t1 = MockOperator.partial(task_id="t1").expand(arg1=[])
with TaskGroup("g"):
t2 = MockOperator.partial(task_id="t2").expand(arg1=[])
assert t1.task_id == "t1"
assert t2.task_id == "g.t2"
assert dag.get_task("t1") == t1
assert dag.get_task("g.t2") == t2
def test_iter_tasks():
with DAG("test_dag", start_date=pendulum.parse("20200101")) as dag:
with TaskGroup("section_1") as tg1:
EmptyOperator(task_id="task1")
with TaskGroup("section_2") as tg2:
task2 = EmptyOperator(task_id="task2")
task3 = EmptyOperator(task_id="task3")
mapped_bash_operator = BashOperator.partial(task_id="bash_task").expand(
bash_command=[
"echo hello 1",
"echo hello 2",
"echo hello 3",
]
)
task2 >> task3 >> mapped_bash_operator
tg1 >> tg2
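# iter_tasks() walks groups recursively; a mapped operator appears once, like
# any other task.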
root_group = dag.task_group
assert [t.task_id for t in root_group.iter_tasks()] == [
"section_1.task1",
"section_2.task2",
"section_2.task3",
"section_2.bash_task",
]
assert [t.task_id for t in tg1.iter_tasks()] == [
"section_1.task1",
]
assert [t.task_id for t in tg2.iter_tasks()] == [
"section_2.task2",
"section_2.task3",
"section_2.bash_task",
]
def test_override_dag_default_args():
with DAG(
dag_id="test_dag",
start_date=pendulum.parse("20200101"),
default_args={
"retries": 1,
"owner": "x",
},
):
with TaskGroup(
group_id="task_group",
default_args={
"owner": "y",
"execution_timeout": timedelta(seconds=10),
},
):
task = EmptyOperator(task_id="task")
assert task.retries == 1
assert task.owner == "y"
assert task.execution_timeout == timedelta(seconds=10)
def test_override_dag_default_args_in_nested_tg():
with DAG(
dag_id="test_dag",
start_date=pendulum.parse("20200101"),
default_args={
"retries": 1,
"owner": "x",
},
):
with TaskGroup(
group_id="task_group",
default_args={
"owner": "y",
"execution_timeout": timedelta(seconds=10),
},
):
with TaskGroup(group_id="nested_task_group"):
task = EmptyOperator(task_id="task")
assert task.retries == 1
assert task.owner == "y"
assert task.execution_timeout == timedelta(seconds=10)
def test_override_dag_default_args_in_multi_level_nested_tg():
with DAG(
dag_id="test_dag",
start_date=pendulum.parse("20200101"),
default_args={
"retries": 1,
"owner": "x",
},
):
with TaskGroup(
group_id="task_group",
default_args={
"owner": "y",
"execution_timeout": timedelta(seconds=10),
},
):
with TaskGroup(
group_id="first_nested_task_group",
default_args={
"owner": "z",
},
):
with TaskGroup(group_id="second_nested_task_group"):
with TaskGroup(group_id="third_nested_task_group"):
task = EmptyOperator(task_id="task")
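# The closest ancestor that sets a key wins: owner from the first nested group,
# execution_timeout from the outer group, retries from the DAG.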
assert task.retries == 1
assert task.owner == "z"
assert task.execution_timeout == timedelta(seconds=10)
def test_task_group_arrow_with_setups_teardowns():
with DAG(dag_id="hi", start_date=pendulum.datetime(2022, 1, 1)):
with TaskGroup(group_id="tg1") as tg1:
s1 = BaseOperator(task_id="s1")
w1 = BaseOperator(task_id="w1")
t1 = BaseOperator(task_id="t1")
s1 >> w1 >> t1.as_teardown(setups=s1)
w2 = BaseOperator(task_id="w2")
tg1 >> w2
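# The teardown t1 is not treated as a leaf of tg1, so tg1 >> w2 only wires the
# work task w1 to w2.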
assert t1.downstream_task_ids == set()
assert w1.downstream_task_ids == {"tg1.t1", "w2"}
def test_task_group_arrow_with_setup_group():
with DAG(dag_id="setup_group_teardown_group", start_date=pendulum.now()):
with TaskGroup("group_1") as g1:
@setup
def setup_1(): ...
@setup
def setup_2(): ...
s1 = setup_1()
s2 = setup_2()
with TaskGroup("group_2") as g2:
@teardown
def teardown_1(): ...
@teardown
def teardown_2(): ...
t1 = teardown_1()
t2 = teardown_2()
@task_decorator
def work(): ...
w1 = work()
g1 >> w1 >> g2
t1.as_teardown(setups=s1)
t2.as_teardown(setups=s2)
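# group_1 contains only setups, so g1 >> w1 connects each setup to the work
# task; as_teardown() then adds the flagged setup >> teardown arrows.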
assert set(s1.operator.downstream_task_ids) == {"work", "group_2.teardown_1"}
assert set(s2.operator.downstream_task_ids) == {"work", "group_2.teardown_2"}
assert set(w1.operator.downstream_task_ids) == {"group_2.teardown_1", "group_2.teardown_2"}
assert set(t1.operator.downstream_task_ids) == set()
assert set(t2.operator.downstream_task_ids) == set()
def get_nodes(group):
d = task_group_to_dict(group)
new_d = {}
new_d["id"] = d["id"]
new_d["children"] = [{"id": x["id"]} for x in d["children"]]
return new_d
assert get_nodes(g1) == {
"id": "group_1",
"children": [
{"id": "group_1.setup_1"},
{"id": "group_1.setup_2"},
{"id": "group_1.downstream_join_id"},
],
}
def test_task_group_arrow_with_setup_group_deeper_setup():
"""
When recursing upstream for a non-teardown leaf, we should ignore setups that
are directly upstream of a teardown.
"""
with DAG(dag_id="setup_group_teardown_group_2", start_date=pendulum.now()):
with TaskGroup("group_1") as g1:
@setup
def setup_1(): ...
@setup
def setup_2(): ...
@teardown
def teardown_0(): ...
s1 = setup_1()
s2 = setup_2()
t0 = teardown_0()
s2 >> t0
with TaskGroup("group_2") as g2:
@teardown
def teardown_1(): ...
@teardown
def teardown_2(): ...
t1 = teardown_1()
t2 = teardown_2()
@task_decorator
def work(): ...
w1 = work()
g1 >> w1 >> g2
t1.as_teardown(setups=s1)
t2.as_teardown(setups=s2)
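# s2 already leads directly into a teardown (t0) inside group_1, so it is
# skipped when wiring group_1's leaves to the work task; only s1 feeds work.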
assert set(s1.operator.downstream_task_ids) == {"work", "group_2.teardown_1"}
assert set(s2.operator.downstream_task_ids) == {"group_1.teardown_0", "group_2.teardown_2"}
assert set(w1.operator.downstream_task_ids) == {"group_2.teardown_1", "group_2.teardown_2"}
assert set(t1.operator.downstream_task_ids) == set()
assert set(t2.operator.downstream_task_ids) == set()