#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for stringified DAGs."""
from __future__ import annotations
import contextlib
import copy
import dataclasses
import importlib
import importlib.util
import json
import multiprocessing
import os
import pickle
import re
import sys
import warnings
from collections.abc import Generator
from datetime import datetime, timedelta, timezone as dt_timezone
from glob import glob
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING
from unittest import mock
import attrs
import pendulum
import pytest
from dateutil.relativedelta import FR, relativedelta
from kubernetes.client import models as k8s
import airflow
from airflow._shared.timezones import timezone
from airflow.dag_processing.dagbag import DagBag
from airflow.exceptions import (
AirflowException,
ParamValidationError,
SerializationError,
)
from airflow.models.asset import AssetModel
from airflow.models.connection import Connection
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.xcom import XCOM_RETURN_KEY, XComModel
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
from airflow.providers.standard.operators.bash import BashOperator
from airflow.sdk import DAG, AssetAlias, BaseHook, WeightRule, teardown
from airflow.sdk.bases.decorator import DecoratedOperator
from airflow.sdk.bases.operator import OPERATOR_DEFAULTS, BaseOperator
from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY
from airflow.sdk.definitions.asset import Asset, AssetUniqueKey
from airflow.sdk.definitions.operator_resources import Resources
from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.security import permissions
from airflow.serialization.definitions.notset import NOTSET
from airflow.serialization.enums import Encoding
from airflow.serialization.json_schema import load_dag_schema_dict
from airflow.serialization.serialized_objects import (
BaseSerialization,
SerializedBaseOperator,
SerializedDAG,
SerializedParam,
XComOperatorLink,
)
from airflow.task.priority_strategy import _AbsolutePriorityWeightStrategy, _DownstreamPriorityWeightStrategy
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
from airflow.timetables.simple import NullTimetable, OnceTimetable
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.module_loading import qualname
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker, skip_if_not_on_main
from tests_common.test_utils.mock_operators import (
AirflowLink,
AirflowLink2,
CustomOperator,
GithubLink,
MockOperator,
)
from tests_common.test_utils.timetables import (
CustomSerializationTimetable,
cron_timetable,
delta_timetable,
)
if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context
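# Most tests in this module exercise the same round trip; an illustrative sketch of the API
# used throughout (all of these calls appear in the tests below):
#
#   dag = make_simple_dag()
#   data = SerializedDAG.to_dict(dag)        # DAG -> JSON-compatible dict
#   SerializedDAG.validate_schema(data)      # check the dict against the JSON schema
#   dag2 = SerializedDAG.from_dict(data)     # dict -> SerializedDAG
#   json_str = SerializedDAG.to_json(dag)    # DAG -> JSON string
#   dag3 = SerializedDAG.from_json(json_str)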
@contextlib.contextmanager
def operator_defaults(overrides):
"""
Temporarily patches OPERATOR_DEFAULTS, restoring original values after context exit.
Example:
with operator_defaults({"retries": 2, "retry_delay": 200.0}):
# Test code with modified operator defaults
"""
from airflow.sdk.bases.operator import OPERATOR_DEFAULTS
    original_values = {}
    missing = object()  # sentinel so a stored value of None is not confused with "key absent"
    try:
        # Store original values and apply overrides
        for key, value in overrides.items():
            original_values[key] = OPERATOR_DEFAULTS.get(key, missing)
            OPERATOR_DEFAULTS[key] = value
        # Clear the cache to ensure fresh generation
        SerializedBaseOperator.generate_client_defaults.cache_clear()
        yield
    finally:
        # Cleanup: restore original values
        for key, original_value in original_values.items():
            if original_value is missing:
                # Key didn't exist originally, remove it
                OPERATOR_DEFAULTS.pop(key, None)
            else:
                # Restore original value
                OPERATOR_DEFAULTS[key] = original_value
# Clear cache again to restore normal behavior
SerializedBaseOperator.generate_client_defaults.cache_clear()
AIRFLOW_REPO_ROOT_PATH = Path(airflow.__file__).parents[3]
executor_config_pod = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(name="my-name"),
spec=k8s.V1PodSpec(
containers=[
k8s.V1Container(
name="base",
volume_mounts=[k8s.V1VolumeMount(name="my-vol", mount_path="/vol/")],
)
]
),
)
TYPE = Encoding.TYPE
VAR = Encoding.VAR
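# Encoding.TYPE and Encoding.VAR resolve to the "__type"/"__var" keys used in the ground truth
# below to wrap values that are not JSON-native (typed dicts, sets, timedeltas, k8s objects, ...).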
serialized_simple_dag_ground_truth = {
"__version": 3,
"dag": {
"default_args": {
"__type": "dict",
"__var": {
"depends_on_past": False,
"retries": 1,
"retry_delay": {"__type": "timedelta", "__var": 240.0},
"max_retry_delay": {"__type": "timedelta", "__var": 600.0},
},
},
"start_date": 1564617600.0,
"timetable": {
"__type": "airflow.timetables.interval.DeltaDataIntervalTimetable",
"__var": {
"delta": 86400.0,
},
},
"task_group": {
"_group_id": None,
"group_display_name": "",
"prefix_group_id": True,
"children": {
"bash_task": ("operator", "bash_task"),
"custom_task": ("operator", "custom_task"),
},
"tooltip": "",
"ui_color": "CornflowerBlue",
"ui_fgcolor": "#000",
"upstream_group_ids": [],
"downstream_group_ids": [],
"upstream_task_ids": [],
"downstream_task_ids": [],
},
"is_paused_upon_creation": False,
"dag_id": "simple_dag",
"deadline": None,
"doc_md": "### DAG Tutorial Documentation",
"fileloc": None,
"_processor_dags_folder": (
AIRFLOW_REPO_ROOT_PATH / "airflow-core" / "tests" / "unit" / "dags"
).as_posix(),
"tasks": [
{
"__type": "operator",
"__var": {
"task_id": "bash_task",
"retries": 1,
"retry_delay": 240.0,
"max_retry_delay": 600.0,
"ui_color": "#f0ede4",
"template_ext": [".sh", ".bash"],
"template_fields": ["bash_command", "env", "cwd"],
"template_fields_renderers": {
"bash_command": "bash",
"env": "json",
},
"bash_command": "echo {{ task.task_id }}",
"task_type": "BashOperator",
"_task_module": "airflow.providers.standard.operators.bash",
"_task_display_name": "my_bash_task",
"owner": "airflow1",
"pool": "pool1",
"executor_config": {
"__type": "dict",
"__var": {
"pod_override": {
"__type": "k8s.V1Pod",
"__var": PodGenerator.serialize_pod(executor_config_pod),
}
},
},
"doc_md": "### Task Tutorial Documentation",
"_needs_expansion": False,
"inlets": [
{
"__type": "asset",
"__var": {
"extra": {},
"group": "asset",
"name": "asset-1",
"uri": "asset-1",
},
},
{
"__type": "asset_alias",
"__var": {"group": "asset", "name": "alias-name"},
},
],
"outlets": [
{
"__type": "asset",
"__var": {
"extra": {},
"group": "asset",
"name": "asset-2",
"uri": "asset-2",
},
},
],
},
},
{
"__type": "operator",
"__var": {
"task_id": "custom_task",
"retries": 1,
"retry_delay": 240.0,
"max_retry_delay": 600.0,
"_operator_extra_links": {"Google Custom": "_link_CustomOpLink"},
"template_fields": ["bash_command"],
"task_type": "CustomOperator",
"_operator_name": "@custom",
"_task_module": "tests_common.test_utils.mock_operators",
"_needs_expansion": False,
},
},
],
"timezone": "UTC",
"access_control": {
"__type": "dict",
"__var": {
"test_role": {
"__type": "dict",
"__var": {
"DAGs": {
"__type": "set",
"__var": [
permissions.ACTION_CAN_READ,
permissions.ACTION_CAN_EDIT,
],
}
},
},
},
},
"edge_info": {},
"dag_dependencies": [
{
"dependency_id": '{"name": "asset-2", "uri": "asset-2"}',
"dependency_type": "asset",
"label": "asset-2",
"source": "simple_dag",
"target": "asset",
},
],
"params": [],
},
}
CUSTOM_TIMETABLE_SERIALIZED = {
"__type": "tests_common.test_utils.timetables.CustomSerializationTimetable",
"__var": {"value": "foo"},
}
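# Expected encoding of CustomSerializationTimetable("foo"); the timetable_plugin fixture below
# registers this class so it can be (de)serialized by its qualified name.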
@pytest.fixture
def testing_assets(session):
from tests_common.test_utils.db import clear_db_assets
assets = [Asset(name=f"asset{i}", uri=f"test://asset{i}/") for i in range(1, 5)]
session.add_all([AssetModel(id=i, name=f"asset{i}", uri=f"test://asset{i}/") for i in range(1, 5)])
session.commit()
yield assets
clear_db_assets()
def make_simple_dag():
"""Make very simple DAG to verify serialization result."""
with DAG(
dag_id="simple_dag",
schedule=timedelta(days=1),
default_args={
"retries": 1,
"retry_delay": timedelta(minutes=4),
"max_retry_delay": timedelta(minutes=10),
"depends_on_past": False,
},
start_date=datetime(2019, 8, 1),
is_paused_upon_creation=False,
access_control={"test_role": {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT}},
doc_md="### DAG Tutorial Documentation",
) as dag:
CustomOperator(task_id="custom_task")
BashOperator(
task_id="bash_task",
bash_command="echo {{ task.task_id }}",
owner="airflow1",
executor_config={"pod_override": executor_config_pod},
doc_md="### Task Tutorial Documentation",
inlets=[Asset("asset-1"), AssetAlias(name="alias-name")],
outlets=Asset("asset-2"),
pool="pool1",
task_display_name="my_bash_task",
)
return dag
def make_user_defined_macro_filter_dag():
"""
Make DAGs with user defined macros and filters using locally defined methods.
For Webserver, we do not include ``user_defined_macros`` & ``user_defined_filters``.
The examples here test:
(1) functions can be successfully displayed on UI;
(2) templates with function macros have been rendered before serialization.
"""
# TODO (GH-52141): Since the worker would not have access to the database in
# production anyway, we should rewrite this test to better match reality.
def compute_last_dagrun(dag: DAG):
from airflow.models.dag import get_last_dagrun
from airflow.utils.session import create_session
with create_session() as session:
return get_last_dagrun(dag.dag_id, session=session, include_manually_triggered=True)
default_args = {"start_date": datetime(2019, 7, 10)}
dag = DAG(
"user_defined_macro_filter_dag",
schedule=None,
default_args=default_args,
user_defined_macros={
"last_dagrun": compute_last_dagrun,
},
user_defined_filters={"hello": lambda name: f"Hello {name}"},
catchup=False,
)
BashOperator(
task_id="echo",
bash_command='echo "{{ last_dagrun(dag) }}"',
dag=dag,
)
return {dag.dag_id: dag}
def get_excluded_patterns() -> Generator[str, None, None]:
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
all_providers = json.loads(
(AIRFLOW_REPO_ROOT_PATH / "generated" / "provider_dependencies.json").read_text()
)
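    # provider_dependencies.json maps each provider id to its metadata; the only key used here is
    # "excluded-python-versions", e.g. (illustrative shape, not real data):
    #   {"some.provider": {"excluded-python-versions": ["3.13"], ...}, ...}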
for provider, provider_info in all_providers.items():
        if python_version in provider_info.get("excluded-python-versions", []):
provider_path = provider.replace(".", "/")
yield f"providers/{provider_path}"
current_python_version = sys.version_info[:2]
if current_python_version >= (3, 13):
# We should remove google when ray is fixed to work with Python 3.13
# and yandex when it is fixed to work with Python 3.13
yield "providers/google/tests/system/google/"
yield "providers/yandex/tests/system/yandex/"
def collect_dags(dag_folder=None):
"""Collects DAGs to test."""
dags = {}
import_errors = {}
dags.update({"simple_dag": make_simple_dag()})
dags.update(make_user_defined_macro_filter_dag())
if dag_folder is None:
patterns = [
"airflow-core/src/airflow/example_dags",
# For now include amazon directly because they have many dags and are all serializing without error
"providers/amazon/tests/system/*/*/",
"providers/*/tests/system/*/",
"providers/*/*/tests/system/*/*/",
]
else:
if isinstance(dag_folder, (list, tuple)):
patterns = dag_folder
else:
patterns = [dag_folder]
excluded_patterns = [
f"{AIRFLOW_REPO_ROOT_PATH}/{excluded_pattern}" for excluded_pattern in get_excluded_patterns()
]
for pattern in patterns:
for directory in glob(f"{AIRFLOW_REPO_ROOT_PATH}/{pattern}"):
            if any(directory.startswith(excluded_pattern) for excluded_pattern in excluded_patterns):
continue
dagbag = DagBag(directory, include_examples=False)
dags.update(dagbag.dags)
import_errors.update(dagbag.import_errors)
return dags, import_errors
def get_timetable_based_simple_dag(timetable):
"""Create a simple_dag variant that uses a timetable."""
dag = make_simple_dag()
dag.timetable = timetable
return dag
def serialize_subprocess(queue, dag_folder):
"""Validate pickle in a subprocess."""
dags, _ = collect_dags(dag_folder)
for dag in dags.values():
queue.put(SerializedDAG.to_json(dag))
queue.put(None)
@pytest.fixture
def timetable_plugin(monkeypatch):
"""Patch plugins manager to always and only return our custom timetable."""
from airflow import plugins_manager
monkeypatch.setattr(plugins_manager, "initialize_timetables_plugins", lambda: None)
monkeypatch.setattr(
plugins_manager,
"timetable_classes",
{"tests_common.test_utils.timetables.CustomSerializationTimetable": CustomSerializationTimetable},
)
class TestStringifiedDAGs:
"""Unit tests for stringified DAGs."""
@pytest.fixture(autouse=True)
def setup_test_cases(self):
with mock.patch.object(BaseHook, "get_connection") as m:
m.return_value = Connection(
extra=(
"{"
'"project_id": "mock", '
'"location": "mock", '
'"instance": "mock", '
'"database_type": "postgres", '
'"use_proxy": "False", '
'"use_ssl": "False"'
"}"
)
)
    # Skip this test if the latest botocore is used - it reads all example dags, and when botocore
    # is upgraded to the latest version, aiobotocore usually can't be installed and some of the
    # system tests fail with import errors. Also skip if not running on the main branch - some of
    # the example dags might fail due to outdated imports in past branches.
@pytest.mark.skipif(
os.environ.get("UPGRADE_BOTO", "") == "true",
reason="This test is skipped when latest botocore is installed",
)
@skip_if_force_lowest_dependencies_marker
@skip_if_not_on_main
@pytest.mark.db_test
def test_serialization(self):
"""Serialization and deserialization should work for every DAG and Operator."""
with warnings.catch_warnings():
dags, import_errors = collect_dags()
serialized_dags = {}
for v in dags.values():
dag = SerializedDAG.to_dict(v)
SerializedDAG.validate_schema(dag)
serialized_dags[v.dag_id] = dag
# Ignore some errors.
import_errors = {
file: error
for file, error in import_errors.items()
# Don't worry about warnings, we only care about errors here -- otherwise
# AirflowProviderDeprecationWarning etc show up in import_errors, and being aware of all of those is
# not relevant to this test; we only care about actual errors
if "airflow.exceptions.AirflowProviderDeprecationWarning" not in error
# TODO: TaskSDK
if "`use_airflow_context=True` is not yet implemented" not in error
# This "looks" like a problem, but is just a quirk of the parse-all-dags-in-one-process we do
# in this test
if "AirflowDagDuplicatedIdException: Ignoring DAG example_sagemaker" not in error
}
# Let's not be exact about this, but if everything fails to parse we should fail this test too
assert import_errors == {}
assert len(dags) > 100
# Compares with the ground truth of JSON string.
actual, expected = self.prepare_ser_dags_for_comparison(
actual=serialized_dags["simple_dag"],
expected=serialized_simple_dag_ground_truth,
)
assert actual == expected
@pytest.mark.db_test
@pytest.mark.parametrize(
("timetable", "serialized_timetable"),
[
(
cron_timetable("0 0 * * *"),
{
"__type": "airflow.timetables.interval.CronDataIntervalTimetable",
"__var": {"expression": "0 0 * * *", "timezone": "UTC"},
},
),
(
CustomSerializationTimetable("foo"),
CUSTOM_TIMETABLE_SERIALIZED,
),
],
)
@pytest.mark.usefixtures("timetable_plugin")
def test_dag_serialization_to_timetable(self, timetable, serialized_timetable):
"""Verify a timetable-backed DAG is serialized correctly."""
dag = get_timetable_based_simple_dag(timetable)
serialized_dag = SerializedDAG.to_dict(dag)
SerializedDAG.validate_schema(serialized_dag)
expected = copy.deepcopy(serialized_simple_dag_ground_truth)
expected["dag"]["timetable"] = serialized_timetable
# these tasks are not mapped / in mapped task group
for task in expected["dag"]["tasks"]:
task["__var"]["_needs_expansion"] = False
actual, expected = self.prepare_ser_dags_for_comparison(
actual=serialized_dag,
expected=expected,
)
assert actual == expected
@pytest.mark.db_test
def test_dag_serialization_preserves_empty_access_roles(self):
"""Verify that an explicitly empty access_control dict is preserved."""
dag = make_simple_dag()
dag.access_control = {}
serialized_dag = SerializedDAG.to_dict(dag)
SerializedDAG.validate_schema(serialized_dag)
assert serialized_dag["dag"]["access_control"] == {
"__type": "dict",
"__var": {},
}
@pytest.mark.db_test
def test_dag_serialization_unregistered_custom_timetable(self):
"""Verify serialization fails without timetable registration."""
dag = get_timetable_based_simple_dag(CustomSerializationTimetable("bar"))
with pytest.raises(SerializationError) as ctx:
SerializedDAG.to_dict(dag)
message = (
"Failed to serialize DAG 'simple_dag': Timetable class "
"'tests_common.test_utils.timetables.CustomSerializationTimetable' "
"is not registered or "
"you have a top level database access that disrupted the session. "
"Please check the airflow best practices documentation."
)
assert str(ctx.value) == message
def prepare_ser_dags_for_comparison(self, actual, expected):
"""Verify serialized DAGs match the ground truth."""
assert actual["dag"]["fileloc"].split("/")[-1] == "test_dag_serialization.py"
actual["dag"]["fileloc"] = None
def sorted_serialized_dag(dag_dict: dict):
"""
Sorts the "tasks" list and "access_control" permissions in the
serialised dag python dictionary. This is needed as the order of
items should not matter but assertEqual would fail if the order of
items changes in the dag dictionary
"""
tasks = []
for task in sorted(dag_dict["dag"]["tasks"], key=lambda x: x["__var"]["task_id"]):
task["__var"] = dict(sorted(task["__var"].items(), key=lambda x: x[0]))
tasks.append(task)
dag_dict["dag"]["tasks"] = tasks
if "access_control" in dag_dict["dag"]:
dag_dict["dag"]["access_control"]["__var"]["test_role"]["__var"] = sorted(
dag_dict["dag"]["access_control"]["__var"]["test_role"]["__var"]
)
return dag_dict
expected = copy.deepcopy(expected)
# by roundtripping to json we get a cleaner diff
# if not doing this, we get false alarms such as "__var" != VAR
actual = json.loads(json.dumps(sorted_serialized_dag(actual)))
expected = json.loads(json.dumps(sorted_serialized_dag(expected)))
return actual, expected
@pytest.mark.db_test
def test_deserialization_across_process(self):
"""A serialized DAG can be deserialized in another process."""
# Since we need to parse the dags twice here (once in the subprocess,
# and once here to get a DAG to compare to) we don't want to load all
# dags.
queue = multiprocessing.Queue()
proc = multiprocessing.Process(target=serialize_subprocess, args=(queue, "airflow/example_dags"))
proc.daemon = True
proc.start()
stringified_dags = {}
while True:
v = queue.get()
if v is None:
break
dag = SerializedDAG.from_json(v)
assert isinstance(dag, SerializedDAG)
stringified_dags[dag.dag_id] = dag
dags, _ = collect_dags("airflow/example_dags")
assert set(stringified_dags.keys()) == set(dags.keys())
# Verify deserialized DAGs.
for dag_id in stringified_dags:
self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id])
@skip_if_force_lowest_dependencies_marker
@pytest.mark.db_test
def test_roundtrip_provider_example_dags(self):
dags, _ = collect_dags(
[
"providers/*/src/airflow/providers/*/example_dags",
"providers/*/src/airflow/providers/*/*/example_dags",
]
)
# Verify deserialized DAGs.
for dag in dags.values():
serialized_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
self.validate_deserialized_dag(serialized_dag, dag)
# Let's not be exact about this, but if everything fails to parse we should fail this test too
assert len(dags) >= 7
@pytest.mark.db_test
@pytest.mark.parametrize(
"timetable",
[cron_timetable("0 0 * * *"), CustomSerializationTimetable("foo")],
)
@pytest.mark.usefixtures("timetable_plugin")
def test_dag_roundtrip_from_timetable(self, timetable):
"""Verify a timetable-backed serialization can be deserialized."""
dag = get_timetable_based_simple_dag(timetable)
roundtripped = SerializedDAG.from_json(SerializedDAG.to_json(dag))
self.validate_deserialized_dag(roundtripped, dag)
def validate_deserialized_dag(self, serialized_dag: SerializedDAG, dag: DAG):
"""
Verify that all example DAGs work with DAG Serialization by
checking fields between Serialized Dags & non-Serialized Dags
"""
exclusion_list = {
# Doesn't implement __eq__ properly. Check manually.
"timetable",
"timezone",
# Need to check fields in it, to exclude functions.
"default_args",
"task_group",
"params",
"_processor_dags_folder",
}
fields_to_check = dag.get_serialized_fields() - exclusion_list
for field in fields_to_check:
actual = getattr(serialized_dag, field)
expected = getattr(dag, field, None)
assert actual == expected, f"{dag.dag_id}.{field} does not match"
# _processor_dags_folder is only populated at serialization time
# it's only used when relying on serialized dag to determine a dag's relative path
assert (
serialized_dag._processor_dags_folder
== (AIRFLOW_REPO_ROOT_PATH / "airflow-core" / "tests" / "unit" / "dags").as_posix()
)
if dag.default_args:
for k, v in dag.default_args.items():
if callable(v):
# Check we stored _something_.
assert k in serialized_dag.default_args
else:
assert v == serialized_dag.default_args[k], (
f"{dag.dag_id}.default_args[{k}] does not match"
)
assert serialized_dag.timetable.summary == dag.timetable.summary
assert serialized_dag.timetable.serialize() == dag.timetable.serialize()
assert serialized_dag.timezone == dag.timezone
for task_id in dag.task_ids:
self.validate_deserialized_task(serialized_dag.get_task(task_id), dag.get_task(task_id))
def validate_deserialized_task(
self,
serialized_task,
task,
):
"""Verify non-Airflow operators are casted to BaseOperator or MappedOperator."""
from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator
from airflow.sdk import BaseOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
assert isinstance(task, (BaseOperator, MappedOperator))
# Every task should have a task_group property -- even if it's the DAG's root task group
assert serialized_task.task_group
if isinstance(task, BaseOperator):
assert isinstance(serialized_task, SerializedBaseOperator)
fields_to_check = task.get_serialized_fields() - {
# Checked separately
"task_type",
"_operator_name",
# Type is excluded, so don't check it
"_log",
# List vs tuple. Check separately
"template_ext",
"template_fields",
# We store the string, real dag has the actual code
"_pre_execute_hook",
"_post_execute_hook",
# Checked separately
"resources",
"on_failure_fail_dagrun",
"_needs_expansion",
"_is_sensor",
}
else: # Promised to be mapped by the assert above.
assert isinstance(serialized_task, SchedulerMappedOperator)
fields_to_check = {f.name for f in attrs.fields(MappedOperator)}
fields_to_check -= {
"map_index_template",
# Matching logic in BaseOperator.get_serialized_fields().
"dag",
"task_group",
# List vs tuple. Check separately.
"operator_extra_links",
"template_ext",
"template_fields",
# Checked separately.
"operator_class",
"partial_kwargs",
"expand_input",
}
assert serialized_task.task_type == task.task_type
assert set(serialized_task.template_ext) == set(task.template_ext)
assert set(serialized_task.template_fields) == set(task.template_fields)
assert serialized_task.upstream_task_ids == task.upstream_task_ids
assert serialized_task.downstream_task_ids == task.downstream_task_ids
for field in fields_to_check:
assert getattr(serialized_task, field) == getattr(task, field), (
f"{task.dag.dag_id}.{task.task_id}.{field} does not match"
)
if serialized_task.resources is None:
assert task.resources is None or task.resources == []
else:
assert serialized_task.resources == task.resources
# `deps` are set in the Scheduler's BaseOperator as that is where we need to evaluate deps
# so only serialized tasks that are sensors should have the ReadyToRescheduleDep.
if task._is_sensor:
assert ReadyToRescheduleDep() in serialized_task.deps
else:
assert ReadyToRescheduleDep() not in serialized_task.deps
# Ugly hack as some operators override params var in their init
if isinstance(task.params, ParamsDict) and isinstance(serialized_task.params, ParamsDict):
assert serialized_task.params.dump() == task.params.dump()
if isinstance(task, MappedOperator):
# MappedOperator.operator_class now stores only minimal type information
# for memory efficiency (task_type and _operator_name).
serialized_task.operator_class["task_type"] == type(task).__name__
if isinstance(serialized_task.operator_class, DecoratedOperator):
serialized_task.operator_class["_operator_name"] == task._operator_name
# Serialization cleans up default values in partial_kwargs, this
# adds them back to both sides.
default_partial_kwargs = (
BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs
)
# These are added in `_TaskDecorator` e.g. when @setup or @teardown task is passed
default_decorator_partial_kwargs = {
"is_setup": False,
"is_teardown": False,
"on_failure_fail_dagrun": False,
}
serialized_partial_kwargs = {
**default_partial_kwargs,
**default_decorator_partial_kwargs,
**serialized_task.partial_kwargs,
}
original_partial_kwargs = {
**default_partial_kwargs,
**default_decorator_partial_kwargs,
**task.partial_kwargs,
}
assert serialized_partial_kwargs == original_partial_kwargs
# ExpandInputs have different classes between scheduler and definition
assert attrs.asdict(serialized_task._get_specified_expand_input()) == attrs.asdict(
task._get_specified_expand_input()
)
@pytest.mark.parametrize(
("dag_start_date", "task_start_date", "expected_task_start_date"),
[
(
datetime(2019, 8, 1, tzinfo=timezone.utc),
None,
datetime(2019, 8, 1, tzinfo=timezone.utc),
),
(
datetime(2019, 8, 1, tzinfo=timezone.utc),
datetime(2019, 8, 2, tzinfo=timezone.utc),
datetime(2019, 8, 2, tzinfo=timezone.utc),
),
(
datetime(2019, 8, 1, tzinfo=timezone.utc),
datetime(2019, 7, 30, tzinfo=timezone.utc),
datetime(2019, 8, 1, tzinfo=timezone.utc),
),
(
datetime(2019, 8, 1, tzinfo=dt_timezone(timedelta(hours=1))),
datetime(2019, 7, 30, tzinfo=dt_timezone(timedelta(hours=1))),
datetime(2019, 8, 1, tzinfo=dt_timezone(timedelta(hours=1))),
),
(
pendulum.datetime(2019, 8, 1, tz="UTC"),
None,
pendulum.datetime(2019, 8, 1, tz="UTC"),
),
],
)
def test_deserialization_start_date(self, dag_start_date, task_start_date, expected_task_start_date):
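        """A task falls back to the DAG start_date unless its own start_date is later."""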
dag = DAG(dag_id="simple_dag", schedule=None, start_date=dag_start_date)
BaseOperator(task_id="simple_task", dag=dag, start_date=task_start_date)
serialized_dag = SerializedDAG.to_dict(dag)
if not task_start_date or dag_start_date >= task_start_date:
# If dag.start_date > task.start_date -> task.start_date=dag.start_date
# because of the logic in dag.add_task()
assert "start_date" not in serialized_dag["dag"]["tasks"][0]["__var"]
else:
assert "start_date" in serialized_dag["dag"]["tasks"][0]["__var"]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict["simple_task"]
assert simple_task.start_date == expected_task_start_date
def test_deserialization_with_dag_context(self):
with DAG(
dag_id="simple_dag",
schedule=None,
start_date=datetime(2019, 8, 1, tzinfo=timezone.utc),
) as dag:
BaseOperator(task_id="simple_task")
# should not raise RuntimeError: dictionary changed size during iteration
SerializedDAG.to_dict(dag)
@pytest.mark.parametrize(
("dag_end_date", "task_end_date", "expected_task_end_date"),
[
(
datetime(2019, 8, 1, tzinfo=timezone.utc),
None,
datetime(2019, 8, 1, tzinfo=timezone.utc),
),
(
datetime(2019, 8, 1, tzinfo=timezone.utc),
datetime(2019, 8, 2, tzinfo=timezone.utc),
datetime(2019, 8, 1, tzinfo=timezone.utc),
),
(
datetime(2019, 8, 1, tzinfo=timezone.utc),
datetime(2019, 7, 30, tzinfo=timezone.utc),
datetime(2019, 7, 30, tzinfo=timezone.utc),
),
],
)
def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_task_end_date):
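        """A task falls back to the DAG end_date unless its own end_date is earlier."""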
dag = DAG(
dag_id="simple_dag",
schedule=None,
start_date=datetime(2019, 8, 1),
end_date=dag_end_date,
)
BaseOperator(task_id="simple_task", dag=dag, end_date=task_end_date)
serialized_dag = SerializedDAG.to_dict(dag)
if not task_end_date or dag_end_date <= task_end_date:
# If dag.end_date < task.end_date -> task.end_date=dag.end_date
# because of the logic in dag.add_task()
assert "end_date" not in serialized_dag["dag"]["tasks"][0]["__var"]
else:
assert "end_date" in serialized_dag["dag"]["tasks"][0]["__var"]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict["simple_task"]
assert simple_task.end_date == expected_task_end_date
@pytest.mark.parametrize(
("serialized_timetable", "expected_timetable"),
[
(
{"__type": "airflow.timetables.simple.NullTimetable", "__var": {}},
NullTimetable(),
),
(
{
"__type": "airflow.timetables.interval.CronDataIntervalTimetable",
"__var": {"expression": "@weekly", "timezone": "UTC"},
},
cron_timetable("0 0 * * 0"),
),
(
{"__type": "airflow.timetables.simple.OnceTimetable", "__var": {}},
OnceTimetable(),
),
(
{
"__type": "airflow.timetables.interval.DeltaDataIntervalTimetable",
"__var": {"delta": 86400.0},
},
delta_timetable(timedelta(days=1)),
),
(CUSTOM_TIMETABLE_SERIALIZED, CustomSerializationTimetable("foo")),
],
)
@pytest.mark.usefixtures("timetable_plugin")
def test_deserialization_timetable(
self,
serialized_timetable,
expected_timetable,
):
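        """A serialized timetable deserializes into an equivalent timetable object."""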
serialized = {
"__version": 3,
"dag": {
"default_args": {"__type": "dict", "__var": {}},
"dag_id": "simple_dag",
"fileloc": __file__,
"tasks": [],
"timezone": "UTC",
"timetable": serialized_timetable,
},
}
SerializedDAG.validate_schema(serialized)
dag = SerializedDAG.from_dict(serialized)
assert dag.timetable == expected_timetable
@pytest.mark.parametrize(
("serialized_timetable", "expected_timetable_summary"),
[
(
{"__type": "airflow.timetables.simple.NullTimetable", "__var": {}},
"None",
),
(
{
"__type": "airflow.timetables.interval.CronDataIntervalTimetable",
"__var": {"expression": "@weekly", "timezone": "UTC"},
},
"0 0 * * 0",
),
(
{"__type": "airflow.timetables.simple.OnceTimetable", "__var": {}},
"@once",
),
(
{
"__type": "airflow.timetables.interval.DeltaDataIntervalTimetable",
"__var": {"delta": 86400.0},
},
"1 day, 0:00:00",
),
(CUSTOM_TIMETABLE_SERIALIZED, "CustomSerializationTimetable('foo')"),
],
)
@pytest.mark.usefixtures("timetable_plugin")
def test_deserialization_timetable_summary(
self,
serialized_timetable,
expected_timetable_summary,
):
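        """The deserialized DAG exposes the expected human-readable timetable_summary."""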
serialized = {
"__version": 3,
"dag": {
"default_args": {"__type": "dict", "__var": {}},
"dag_id": "simple_dag",
"fileloc": __file__,
"tasks": [],
"timezone": "UTC",
"timetable": serialized_timetable,
},
}
SerializedDAG.validate_schema(serialized)
dag = SerializedDAG.from_dict(serialized)
assert dag.timetable_summary == expected_timetable_summary
def test_deserialization_timetable_unregistered(self):
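        """Deserialization raises a helpful error when the custom timetable class is not registered."""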
serialized = {
"__version": 3,
"dag": {
"default_args": {"__type": "dict", "__var": {}},
"dag_id": "simple_dag",
"fileloc": __file__,
"tasks": [],
"timezone": "UTC",
"timetable": CUSTOM_TIMETABLE_SERIALIZED,
},
}
SerializedDAG.validate_schema(serialized)
message = (
"Timetable class "
"'tests_common.test_utils.timetables.CustomSerializationTimetable' "
"is not registered or "
"you have a top level database access that disrupted the session. "
"Please check the airflow best practices documentation."
)
with pytest.raises(ValueError, match=message):
SerializedDAG.from_dict(serialized)
@pytest.mark.parametrize(
("val", "expected"),
[
(
relativedelta(days=-1),
{"__type": "relativedelta", "__var": {"days": -1}},
),
(
relativedelta(month=1, days=-1),
{"__type": "relativedelta", "__var": {"month": 1, "days": -1}},
),
# Every friday
(
relativedelta(weekday=FR),
{"__type": "relativedelta", "__var": {"weekday": [4]}},
),
# Every second friday
(
relativedelta(weekday=FR(2)),
{"__type": "relativedelta", "__var": {"weekday": [4, 2]}},
),
],
)
def test_roundtrip_relativedelta(self, val, expected):
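        """relativedelta values survive a serialize/deserialize round trip."""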
serialized = SerializedDAG.serialize(val)
assert serialized == expected
round_tripped = SerializedDAG.deserialize(serialized)
assert val == round_tripped
@pytest.mark.parametrize(
("val", "expected_val"),
[
(None, {}),
({"param_1": "value_1"}, {"param_1": "value_1"}),
({"param_1": {1, 2, 3}}, ParamValidationError),
],
)
def test_dag_params_roundtrip(self, val, expected_val):
"""
Test that params work both on Serialized DAGs & Tasks
"""
if expected_val == ParamValidationError:
with pytest.raises(ParamValidationError):
dag = DAG(dag_id="simple_dag", schedule=None, params=val)
# further tests not relevant
return
dag = DAG(dag_id="simple_dag", schedule=None, params=val)
BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1))
serialized_dag_json = SerializedDAG.to_json(dag)
serialized_dag = json.loads(serialized_dag_json)
assert "params" in serialized_dag["dag"]
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_simple_task = deserialized_dag.task_dict["simple_task"]
assert expected_val == deserialized_dag.params.dump()
assert expected_val == deserialized_simple_task.params.dump()
def test_invalid_params(self):
"""
Test to make sure that only native Param objects are being passed as dag or task params
"""
class S3Param(Param):
def __init__(self, path: str):
schema = {"type": "string", "pattern": r"s3:\/\/(.+?)\/(.+)"}
super().__init__(default=path, schema=schema)
dag = DAG(
dag_id="simple_dag",
schedule=None,
params={"path": S3Param("s3://my_bucket/my_path")},
)
with pytest.raises(SerializationError):
SerializedDAG.to_dict(dag)
dag = DAG(dag_id="simple_dag", schedule=None)
BaseOperator(
task_id="simple_task",
dag=dag,
start_date=datetime(2019, 8, 1),
params={"path": S3Param("s3://my_bucket/my_path")},
)
@pytest.mark.parametrize(
"param",
[
Param("my value", description="hello", schema={"type": "string"}),
Param("my value", description="hello"),
Param(None, description=None),
Param([True], type="array", items={"type": "boolean"}),
Param(),
],
)
def test_full_param_roundtrip(self, param: Param):
"""
Test to make sure that only native Param objects are being passed as dag or task params
"""
sdk_dag = DAG(dag_id="simple_dag", schedule=None, params={"my_param": param})
serialized_json = SerializedDAG.to_json(sdk_dag)
serialized = json.loads(serialized_json)
SerializedDAG.validate_schema(serialized)
dag = SerializedDAG.from_dict(serialized)
assert dag.params.get_param("my_param").value == param.value
observed_param = dag.params.get_param("my_param")
assert isinstance(observed_param, SerializedParam)
assert observed_param.description == param.description
assert observed_param.schema == param.schema
assert observed_param.dump() == {
"value": None if param.value is NOTSET else param.value,
"schema": param.schema,
"description": param.description,
}
@pytest.mark.parametrize(
("val", "expected_val"),
[
(None, {}),
({"param_1": "value_1"}, {"param_1": "value_1"}),
({"param_1": {1, 2, 3}}, ParamValidationError),
],
)
def test_task_params_roundtrip(self, val, expected_val):
"""
Test that params work both on Serialized DAGs & Tasks
"""
dag = DAG(dag_id="simple_dag", schedule=None)
if expected_val == ParamValidationError:
with pytest.raises(ParamValidationError):
BaseOperator(
task_id="simple_task",
dag=dag,
params=val,
start_date=datetime(2019, 8, 1),
)
# further tests not relevant
return
BaseOperator(
task_id="simple_task",
dag=dag,
params=val,
start_date=datetime(2019, 8, 1),
)
serialized_dag = SerializedDAG.to_dict(dag)
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
if val:
assert "params" in serialized_dag["dag"]["tasks"][0]["__var"]
else:
assert "params" not in serialized_dag["dag"]["tasks"][0]["__var"]
deserialized_simple_task = deserialized_dag.task_dict["simple_task"]
assert expected_val == deserialized_simple_task.params.dump()
@pytest.mark.db_test
@pytest.mark.parametrize(
("bash_command", "serialized_links", "links"),
[
pytest.param(
"true",
{"Google Custom": "_link_CustomOpLink"},
{"Google Custom": "http://google.com/custom_base_link?search=true"},
id="non-indexed-link",
),
pytest.param(
["echo", "true"],
{"BigQuery Console #1": "bigquery_1", "BigQuery Console #2": "bigquery_2"},
{
"BigQuery Console #1": "https://console.cloud.google.com/bigquery?j=echo",
"BigQuery Console #2": "https://console.cloud.google.com/bigquery?j=true",
},
id="multiple-indexed-links",
),
],
)
def test_extra_serialized_field_and_operator_links(
self, bash_command, serialized_links, links, dag_maker
):
"""
Assert extra field exists & OperatorLinks defined in Plugins and inbuilt Operator Links.
This tests also depends on GoogleLink() registered as a plugin
in tests/plugins/test_plugin.py
The function tests that if extra operator links are registered in plugin
in ``operator_extra_links`` and the same is also defined in
the Operator in ``BaseOperator.operator_extra_links``, it has the correct
extra link.
If CustomOperator is called with a string argument for bash_command it
has a single link, if called with an array it has one link per element.
We use this to test the serialization of link data.
"""
test_date = timezone.DateTime(2019, 8, 1, tzinfo=timezone.utc)
with dag_maker(dag_id="simple_dag", start_date=test_date) as dag:
CustomOperator(task_id="simple_task", bash_command=bash_command)
serialized_dag = SerializedDAG.to_dict(dag)
assert "bash_command" in serialized_dag["dag"]["tasks"][0]["__var"]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict["simple_task"]
assert getattr(simple_task, "bash_command") == bash_command
#########################################################
# Verify Operator Links work with Serialized Operator
#########################################################
# Check Serialized version of operator link only contains the inbuilt Op Link
assert serialized_dag["dag"]["tasks"][0]["__var"]["_operator_extra_links"] == serialized_links
# Test all the extra_links are set
assert simple_task.extra_links == sorted({*links, "airflow", "github", "google"})
dr = dag_maker.create_dagrun(logical_date=test_date)
(ti,) = dr.task_instances
XComModel.set(
key="search_query",
value=bash_command,
task_id=simple_task.task_id,
dag_id=simple_task.dag_id,
run_id=dr.run_id,
)
# Test Deserialized inbuilt link
for i, (name, expected) in enumerate(links.items()):
# staging the part where a task at runtime pushes xcom for extra links
XComModel.set(
key=simple_task.operator_extra_links[i].xcom_key,
value=expected,
task_id=simple_task.task_id,
dag_id=simple_task.dag_id,
run_id=dr.run_id,
)
link = simple_task.get_extra_links(ti, name)
assert link == expected
current_python_version = sys.version_info[:2]
        if current_python_version < (3, 13):
# TODO(potiuk) We should bring it back when ray is supported on Python 3.13
# Test Deserialized link registered via Airflow Plugin
from tests_common.test_utils.mock_operators import GoogleLink
link = simple_task.get_extra_links(ti, GoogleLink.name)
assert link == "https://www.google.com"
class ClassWithCustomAttributes:
"""
Class for testing purpose: allows to create objects with custom attributes in one single statement.
"""
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
def __str__(self):
return f"{self.__class__.__name__}({str(self.__dict__)})"
def __repr__(self):
return self.__str__()
def __eq__(self, other):
return self.__dict__ == other.__dict__
        def __hash__(self):
            # dicts are not hashable, so hash a stable string representation instead
            return hash(str(self.__dict__))
def __ne__(self, other):
return not self.__eq__(other)
@pytest.mark.parametrize(
("templated_field", "expected_field"),
[
(None, None),
([], []),
({}, {}),
("{{ task.task_id }}", "{{ task.task_id }}"),
(["{{ task.task_id }}", "{{ task.task_id }}"]),
({"foo": "{{ task.task_id }}"}, {"foo": "{{ task.task_id }}"}),
(
{"foo": {"bar": "{{ task.task_id }}"}},
{"foo": {"bar": "{{ task.task_id }}"}},
),
(
[
{"foo1": {"bar": "{{ task.task_id }}"}},
{"foo2": {"bar": "{{ task.task_id }}"}},
],
[
{"foo1": {"bar": "{{ task.task_id }}"}},
{"foo2": {"bar": "{{ task.task_id }}"}},
],
),
(
{"foo": {"bar": {"{{ task.task_id }}": ["sar"]}}},
{"foo": {"bar": {"{{ task.task_id }}": ["sar"]}}},
),
(
ClassWithCustomAttributes(
att1="{{ task.task_id }}",
att2="{{ task.task_id }}",
template_fields=["att1"],
),
"ClassWithCustomAttributes("
"{'att1': '{{ task.task_id }}', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']})",
),
(
ClassWithCustomAttributes(
nested1=ClassWithCustomAttributes(
att1="{{ task.task_id }}",
att2="{{ task.task_id }}",
template_fields=["att1"],
),
nested2=ClassWithCustomAttributes(
att3="{{ task.task_id }}",
att4="{{ task.task_id }}",
template_fields=["att3"],
),
template_fields=["nested1"],
),
"ClassWithCustomAttributes("
"{'nested1': ClassWithCustomAttributes({'att1': '{{ task.task_id }}', "
"'att2': '{{ task.task_id }}', 'template_fields': ['att1']}), "
"'nested2': ClassWithCustomAttributes({'att3': '{{ task.task_id }}', 'att4': "
"'{{ task.task_id }}', 'template_fields': ['att3']}), 'template_fields': ['nested1']})",
),
],
)
def test_templated_fields_exist_in_serialized_dag(self, templated_field, expected_field):
"""
Test that templated_fields exists for all Operators in Serialized DAG
Since we don't want to inflate arbitrary python objects (it poses a RCE/security risk etc.)
we want check that non-"basic" objects are turned in to strings after deserializing.
"""
dag = DAG(
"test_serialized_template_fields",
schedule=None,
start_date=datetime(2019, 8, 1),
)
with dag:
BashOperator(task_id="test", bash_command=templated_field)
serialized_dag = SerializedDAG.to_dict(dag)
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_test_task = deserialized_dag.task_dict["test"]
assert expected_field == getattr(deserialized_test_task, "bash_command")
def test_dag_serialized_fields_with_schema(self):
"""
Additional Properties are disabled on DAGs. This test verifies that all the
keys in DAG.get_serialized_fields are listed in Schema definition.
"""
dag_schema: dict = load_dag_schema_dict()["definitions"]["dag"]["properties"]
# The parameters we add manually in Serialization need to be ignored
ignored_keys: set = {
"_processor_dags_folder",
"tasks",
"has_on_success_callback",
"has_on_failure_callback",
"dag_dependencies",
"params",
}
keys_for_backwards_compat: set = {
"_concurrency",
}
dag_params: set = set(dag_schema.keys()) - ignored_keys - keys_for_backwards_compat
assert set(DAG.get_serialized_fields()) == dag_params
def test_operator_subclass_changing_base_defaults(self):
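        """A subclass that changes a BaseOperator default keeps the changed value through serialization."""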
assert BaseOperator(task_id="dummy").do_xcom_push is True, (
"Precondition check! If this fails the test won't make sense"
)
class MyOperator(BaseOperator):
def __init__(self, do_xcom_push=False, **kwargs):
super().__init__(**kwargs)
self.do_xcom_push = do_xcom_push
op = MyOperator(task_id="dummy")
assert op.do_xcom_push is False
blob = SerializedBaseOperator.serialize_operator(op)
serialized_op = SerializedBaseOperator.deserialize_operator(blob)
assert serialized_op.do_xcom_push is False
def test_no_new_fields_added_to_base_operator(self):
"""
This test verifies that there are no new fields added to BaseOperator. And reminds that
tests should be added for it.
"""
from airflow.task.trigger_rule import TriggerRule
base_operator = BaseOperator(task_id="10")
# Return the name of any annotated class property, or anything explicitly listed in serialized fields
field_names = {
fld.name
for fld in dataclasses.fields(BaseOperator)
if fld.name in BaseOperator.get_serialized_fields()
} | BaseOperator.get_serialized_fields()
fields = {k: getattr(base_operator, k) for k in field_names}
assert fields == {
"_logger_name": None,
"_needs_expansion": None,
"_post_execute_hook": None,
"_pre_execute_hook": None,
"_task_display_name": None,
"allow_nested_operators": True,
"depends_on_past": False,
"do_xcom_push": True,
"doc": None,
"doc_json": None,
"doc_md": None,
"doc_rst": None,
"doc_yaml": None,
"downstream_task_ids": set(),
"end_date": None,
"email": None,
"email_on_failure": True,
"email_on_retry": True,
"execution_timeout": None,
"executor": None,
"executor_config": {},
"has_on_execute_callback": False,
"has_on_failure_callback": False,
"has_on_retry_callback": False,
"has_on_skipped_callback": False,
"has_on_success_callback": False,
"ignore_first_depends_on_past": False,
"is_setup": False,
"is_teardown": False,
"inlets": [],
"map_index_template": None,
"max_active_tis_per_dag": None,
"max_active_tis_per_dagrun": None,
"max_retry_delay": None,
"on_failure_fail_dagrun": False,
"outlets": [],
"owner": "airflow",
"params": {},
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 1,
"queue": "default",
"resources": None,
"retries": 0,
"retry_delay": timedelta(0, 300),
"retry_exponential_backoff": False,
"run_as_user": None,
"start_date": None,
"start_from_trigger": False,
"start_trigger_args": None,
"task_id": "10",
"task_type": "BaseOperator",
"template_ext": (),
"template_fields": (),
"template_fields_renderers": {},
"trigger_rule": TriggerRule.ALL_SUCCESS,
"ui_color": "#fff",
"ui_fgcolor": "#000",
"wait_for_downstream": False,
"wait_for_past_depends_before_skipping": False,
"weight_rule": _DownstreamPriorityWeightStrategy(),
"multiple_outputs": False,
}, """
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
ACTION NEEDED! PLEASE READ THIS CAREFULLY AND CORRECT TESTS CAREFULLY
Some fields were added to the BaseOperator! Please add them to the list above and make sure that
you add support for DAG serialization - you should add the field to
`airflow/serialization/schema.json` - they should have correct type defined there.
Note that we do not support versioning yet so you should only add optional fields to BaseOperator.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
"""
def test_operator_deserialize_old_names(self):
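        """Blobs serialized with the old ``_downstream_task_ids`` field name still deserialize correctly."""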
blob = {
"task_id": "custom_task",
"_downstream_task_ids": ["foo"],
"template_ext": [],
"template_fields": ["bash_command"],
"template_fields_renderers": {},
"task_type": "CustomOperator",
"_task_module": "tests_common.test_utils.mock_operators",
"pool": "default_pool",
"ui_color": "#fff",
"ui_fgcolor": "#000",
}
SerializedDAG._json_schema.validate(blob, _schema=load_dag_schema_dict()["definitions"]["operator"])
serialized_op = SerializedBaseOperator.deserialize_operator(blob)
assert serialized_op.downstream_task_ids == {"foo"}
def test_task_resources(self):
"""
Test task resources serialization/deserialization.
"""
from airflow.providers.standard.operators.empty import EmptyOperator
logical_date = datetime(2020, 1, 1)
task_id = "task1"
with DAG("test_task_resources", schedule=None, start_date=logical_date) as dag:
task = EmptyOperator(task_id=task_id, resources={"cpus": 0.1, "ram": 2048})
SerializedDAG.validate_schema(SerializedDAG.to_dict(dag))
json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
deserialized_task = json_dag.get_task(task_id)
assert deserialized_task.resources == task.resources
assert isinstance(deserialized_task.resources, Resources)
def test_task_group_serialization(self):
"""
Test TaskGroup serialization/deserialization.
"""
from airflow.providers.standard.operators.empty import EmptyOperator
logical_date = datetime(2020, 1, 1)
with DAG("test_task_group_serialization", schedule=None, start_date=logical_date) as dag:
task1 = EmptyOperator(task_id="task1")
with TaskGroup("group234") as group234:
_ = EmptyOperator(task_id="task2")
with TaskGroup("group34") as group34:
_ = EmptyOperator(task_id="task3")
_ = EmptyOperator(task_id="task4")
task5 = EmptyOperator(task_id="task5")
task1 >> group234
group34 >> task5
dag_dict = SerializedDAG.to_dict(dag)
SerializedDAG.validate_schema(dag_dict)
json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
self.validate_deserialized_dag(json_dag, dag)
serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
assert serialized_dag.task_group.children
assert serialized_dag.task_group.children.keys() == dag.task_group.children.keys()
def check_task_group(node):
assert node.dag is serialized_dag
try:
children = node.children.values()
except AttributeError:
# Round-trip serialization and check the result
expected_serialized = SerializedBaseOperator.serialize_operator(dag.get_task(node.task_id))
expected_deserialized = SerializedBaseOperator.deserialize_operator(expected_serialized)
expected_dict = SerializedBaseOperator.serialize_operator(expected_deserialized)
assert node
assert SerializedBaseOperator.serialize_operator(node) == expected_dict
return
for child in children:
check_task_group(child)
check_task_group(serialized_dag.task_group)
@staticmethod
def assert_taskgroup_children(se_task_group, dag_task_group, expected_children):
assert se_task_group.children.keys() == dag_task_group.children.keys() == expected_children
@staticmethod
def assert_task_is_setup_teardown(task, is_setup: bool = False, is_teardown: bool = False):
assert task.is_setup == is_setup
assert task.is_teardown == is_teardown
def test_setup_teardown_tasks(self):
"""
Test setup and teardown task serialization/deserialization.
"""
from airflow.providers.standard.operators.empty import EmptyOperator
logical_date = datetime(2020, 1, 1)
with DAG(
"test_task_group_setup_teardown_tasks",
schedule=None,
start_date=logical_date,
) as dag:
EmptyOperator(task_id="setup").as_setup()
EmptyOperator(task_id="teardown").as_teardown()
with TaskGroup("group1"):
EmptyOperator(task_id="setup1").as_setup()
EmptyOperator(task_id="task1")
EmptyOperator(task_id="teardown1").as_teardown()
with TaskGroup("group2"):
EmptyOperator(task_id="setup2").as_setup()
EmptyOperator(task_id="task2")
EmptyOperator(task_id="teardown2").as_teardown()
dag_dict = SerializedDAG.to_dict(dag)
SerializedDAG.validate_schema(dag_dict)
json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
self.validate_deserialized_dag(json_dag, dag)
serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
self.assert_taskgroup_children(
serialized_dag.task_group, dag.task_group, {"setup", "teardown", "group1"}
)
self.assert_task_is_setup_teardown(serialized_dag.task_group.children["setup"], is_setup=True)
self.assert_task_is_setup_teardown(serialized_dag.task_group.children["teardown"], is_teardown=True)
se_first_group = serialized_dag.task_group.children["group1"]
dag_first_group = dag.task_group.children["group1"]
self.assert_taskgroup_children(
se_first_group,
dag_first_group,
{"group1.setup1", "group1.task1", "group1.group2", "group1.teardown1"},
)
self.assert_task_is_setup_teardown(se_first_group.children["group1.setup1"], is_setup=True)
self.assert_task_is_setup_teardown(se_first_group.children["group1.task1"])
self.assert_task_is_setup_teardown(se_first_group.children["group1.teardown1"], is_teardown=True)
se_second_group = se_first_group.children["group1.group2"]
dag_second_group = dag_first_group.children["group1.group2"]
self.assert_taskgroup_children(
se_second_group,
dag_second_group,
{"group1.group2.setup2", "group1.group2.task2", "group1.group2.teardown2"},
)
self.assert_task_is_setup_teardown(se_second_group.children["group1.group2.setup2"], is_setup=True)
self.assert_task_is_setup_teardown(se_second_group.children["group1.group2.task2"])
self.assert_task_is_setup_teardown(
se_second_group.children["group1.group2.teardown2"], is_teardown=True
)
@pytest.mark.db_test
def test_teardown_task_on_failure_fail_dagrun_serialization(self, dag_maker):
with dag_maker() as dag:
@teardown(on_failure_fail_dagrun=True)
def mytask():
print(1)
mytask()
dag_dict = SerializedDAG.to_dict(dag)
SerializedDAG.validate_schema(dag_dict)
json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
self.validate_deserialized_dag(json_dag, dag)
serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
task = serialized_dag.task_group.children["mytask"]
assert task.is_teardown is True
assert task.on_failure_fail_dagrun is True
@pytest.mark.db_test
def test_basic_mapped_dag(self, dag_maker):
dagbag = DagBag(
"airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py", include_examples=False
)
assert not dagbag.import_errors
dag = dagbag.dags["example_dynamic_task_mapping"]
ser_dag = SerializedDAG.to_dict(dag)
# We should not include `_is_sensor` most of the time (as it would be wasteful). Check we don't
assert "_is_sensor" not in ser_dag["dag"]["tasks"][0]["__var"]
SerializedDAG.validate_schema(ser_dag)
@pytest.mark.db_test
def test_teardown_mapped_serialization(self, dag_maker):
with dag_maker() as dag:
@teardown(on_failure_fail_dagrun=True)
def mytask(val=None):
print(1)
mytask.expand(val=[1, 2, 3])
task = dag.task_group.children["mytask"]
assert task.partial_kwargs["is_teardown"] is True
assert task.partial_kwargs["on_failure_fail_dagrun"] is True
dag_dict = SerializedDAG.to_dict(dag)
SerializedDAG.validate_schema(dag_dict)
json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
self.validate_deserialized_dag(json_dag, dag)
serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
task = serialized_dag.task_group.children["mytask"]
assert task.partial_kwargs["is_teardown"] is True
assert task.partial_kwargs["on_failure_fail_dagrun"] is True
def test_serialize_mapped_outlets(self):
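        """Empty inlets/outlets are omitted when serializing a mapped operator and restored as empty lists."""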
with DAG(dag_id="d", schedule=None, start_date=datetime.now()):
op = MockOperator.partial(task_id="x").expand(arg1=[1, 2])
assert op.inlets == []
assert op.outlets == []
serialized = SerializedBaseOperator.serialize_mapped_operator(op)
assert "inlets" not in serialized
assert "outlets" not in serialized
round_tripped = SerializedBaseOperator.deserialize_operator(serialized)
assert isinstance(round_tripped, MappedOperator)
assert round_tripped.inlets == []
assert round_tripped.outlets == []
@pytest.mark.db_test
@pytest.mark.parametrize("mapped", [False, True])
def test_derived_dag_deps_sensor(self, mapped):
"""
Tests DAG dependency detection for sensors, including derived classes
"""
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.sensors.external_task import ExternalTaskSensor
class DerivedSensor(ExternalTaskSensor):
pass
logical_date = datetime(2020, 1, 1)
for class_ in [ExternalTaskSensor, DerivedSensor]:
with DAG(dag_id="test_derived_dag_deps_sensor", schedule=None, start_date=logical_date) as dag:
if mapped:
task1 = class_.partial(
task_id="task1",
external_dag_id="external_dag_id",
mode="reschedule",
).expand(external_task_id=["some_task", "some_other_task"])
else:
task1 = class_(
task_id="task1",
external_dag_id="external_dag_id",
mode="reschedule",
)
task2 = EmptyOperator(task_id="task2")
task1 >> task2
dag = SerializedDAG.to_dict(dag)
assert dag["dag"]["dag_dependencies"] == [
{
"source": "external_dag_id",
"target": "test_derived_dag_deps_sensor",
"label": "task1",
"dependency_type": "sensor",
"dependency_id": "task1",
}
]
@pytest.mark.db_test
def test_dag_deps_assets_with_duplicate_asset(self, testing_assets):
"""
Check that dag_dependencies node is populated correctly for a DAG with duplicate assets.
"""
from airflow.providers.standard.sensors.external_task import ExternalTaskSensor
logical_date = datetime(2020, 1, 1)
with DAG(dag_id="test", start_date=logical_date, schedule=[testing_assets[0]] * 5) as dag:
ExternalTaskSensor(
task_id="task1",
external_dag_id="external_dag_id",
mode="reschedule",
)
BashOperator(
task_id="asset_writer",
bash_command="echo hello",
outlets=[testing_assets[1]] * 3 + testing_assets[2:3],
)
@dag.task(outlets=[testing_assets[3]])
def other_asset_writer(x):
pass
other_asset_writer.expand(x=[1, 2])
testing_asset_key_strs = [AssetUniqueKey.from_asset(asset).to_str() for asset in testing_assets]
dag = SerializedDAG.to_dict(dag)
actual = sorted(dag["dag"]["dag_dependencies"], key=lambda x: tuple(x.values()))
expected = sorted(
[
{
"source": "test",
"target": "asset",
"label": "asset4",
"dependency_type": "asset",
"dependency_id": testing_asset_key_strs[3],
},
{
"source": "external_dag_id",
"target": "test",
"label": "task1",
"dependency_type": "sensor",
"dependency_id": "task1",
},
{
"source": "test",
"target": "asset",
"label": "asset3",
"dependency_type": "asset",
"dependency_id": testing_asset_key_strs[2],
},
{
"source": "test",
"target": "asset",
"label": "asset2",
"dependency_type": "asset",
"dependency_id": testing_asset_key_strs[1],
},
{
"source": "asset",
"target": "test",
"label": "asset1",
"dependency_type": "asset",
"dependency_id": testing_asset_key_strs[0],
},
{
"source": "asset",
"target": "test",
"label": "asset1",
"dependency_type": "asset",
"dependency_id": testing_asset_key_strs[0],
},
{
"source": "asset",
"target": "test",
"label": "asset1",
"dependency_type": "asset",
"dependency_id": testing_asset_key_strs[0],
},
{
"source": "asset",
"target": "test",
"label": "asset1",
"dependency_type": "asset",
"dependency_id": testing_asset_key_strs[0],
},
{
"source": "asset",
"target": "test",
"label": "asset1",
"dependency_type": "asset",
"dependency_id": testing_asset_key_strs[0],
},
],
key=lambda x: tuple(x.values()),
)
assert actual == expected
@pytest.mark.db_test
def test_dag_deps_assets(self, testing_assets):
"""
Check that the dag_dependencies node is populated correctly for a DAG with assets.
Note that the asset id is not stored at this stage; it is evaluated later when
SerializedDagModel.get_dag_dependencies is called.
"""
from airflow.providers.standard.sensors.external_task import ExternalTaskSensor
logical_date = datetime(2020, 1, 1)
with DAG(dag_id="test", start_date=logical_date, schedule=testing_assets[0:1]) as dag:
ExternalTaskSensor(
task_id="task1",
external_dag_id="external_dag_id",
mode="reschedule",
)
BashOperator(task_id="asset_writer", bash_command="echo hello", outlets=testing_assets[1:3])
@dag.task(outlets=testing_assets[3:])
def other_asset_writer(x):
pass
other_asset_writer.expand(x=[1, 2])
testing_asset_key_strs = [AssetUniqueKey.from_asset(asset).to_str() for asset in testing_assets]
dag = SerializedDAG.to_dict(dag)
actual = sorted(dag["dag"]["dag_dependencies"], key=lambda x: tuple(x.values()))
expected = sorted(
[
{
"source": "test",
"target": "asset",
"label": "asset4",
"dependency_type": "asset",
"dependency_id": testing_asset_key_strs[3],
},
{
"source": "external_dag_id",
"target": "test",
"label": "task1",
"dependency_type": "sensor",
"dependency_id": "task1",
},
{
"source": "test",
"target": "asset",
"label": "asset3",
"dependency_type": "asset",
"dependency_id": testing_asset_key_strs[2],
},
{
"source": "test",
"target": "asset",
"label": "asset2",
"dependency_type": "asset",
"dependency_id": testing_asset_key_strs[1],
},
{
"source": "asset",
"target": "test",
"label": "asset1",
"dependency_type": "asset",
"dependency_id": testing_asset_key_strs[0],
},
],
key=lambda x: tuple(x.values()),
)
assert actual == expected
@pytest.mark.parametrize("mapped", [False, True])
def test_derived_dag_deps_operator(self, mapped):
"""
Tests DAG dependency detection for operators, including derived classes
"""
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.trigger_dagrun import (
TriggerDagRunOperator,
)
class DerivedOperator(TriggerDagRunOperator):
pass
logical_date = datetime(2020, 1, 1)
for class_ in [TriggerDagRunOperator, DerivedOperator]:
with DAG(
dag_id="test_derived_dag_deps_trigger",
schedule=None,
start_date=logical_date,
) as dag:
task1 = EmptyOperator(task_id="task1")
if mapped:
task2 = class_.partial(
task_id="task2",
trigger_dag_id="trigger_dag_id",
).expand(trigger_run_id=["one", "two"])
else:
task2 = class_(
task_id="task2",
trigger_dag_id="trigger_dag_id",
)
task1 >> task2
dag = SerializedDAG.to_dict(dag)
assert dag["dag"]["dag_dependencies"] == [
{
"source": "test_derived_dag_deps_trigger",
"target": "trigger_dag_id",
"label": "task2",
"dependency_type": "trigger",
"dependency_id": "task2",
}
]
def test_task_group_sorted(self):
"""
Tests serialize_task_group, making sure the upstream/downstream lists are in order.
"""
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.serialization.serialized_objects import TaskGroupSerialization
"""
start
╱ ╲
╱ ╲
task_group_up1 task_group_up2
(task_up1) (task_up2)
╲ ╱
task_group_middle
(task_middle)
╱ ╲
task_group_down1 task_group_down2
(task_down1) (task_down2)
╲ ╱
╲ ╱
end
"""
logical_date = datetime(2020, 1, 1)
with DAG(dag_id="test_task_group_sorted", schedule=None, start_date=logical_date) as dag:
start = EmptyOperator(task_id="start")
with TaskGroup("task_group_up1") as task_group_up1:
_ = EmptyOperator(task_id="task_up1")
with TaskGroup("task_group_up2") as task_group_up2:
_ = EmptyOperator(task_id="task_up2")
with TaskGroup("task_group_middle") as task_group_middle:
_ = EmptyOperator(task_id="task_middle")
with TaskGroup("task_group_down1") as task_group_down1:
_ = EmptyOperator(task_id="task_down1")
with TaskGroup("task_group_down2") as task_group_down2:
_ = EmptyOperator(task_id="task_down2")
end = EmptyOperator(task_id="end")
start >> task_group_up1
start >> task_group_up2
task_group_up1 >> task_group_middle
task_group_up2 >> task_group_middle
task_group_middle >> task_group_down1
task_group_middle >> task_group_down2
task_group_down1 >> end
task_group_down2 >> end
task_group_middle_dict = TaskGroupSerialization.serialize_task_group(
dag.task_group.children["task_group_middle"]
)
upstream_group_ids = task_group_middle_dict["upstream_group_ids"]
assert upstream_group_ids == ["task_group_up1", "task_group_up2"]
upstream_task_ids = task_group_middle_dict["upstream_task_ids"]
assert upstream_task_ids == [
"task_group_up1.task_up1",
"task_group_up2.task_up2",
]
downstream_group_ids = task_group_middle_dict["downstream_group_ids"]
assert downstream_group_ids == ["task_group_down1", "task_group_down2"]
task_group_down1_dict = TaskGroupSerialization.serialize_task_group(
dag.task_group.children["task_group_down1"]
)
downstream_task_ids = task_group_down1_dict["downstream_task_ids"]
assert downstream_task_ids == ["end"]
def test_edge_info_serialization(self):
"""
Tests edge_info serialization/deserialization.
"""
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import Label
with DAG(
"test_edge_info_serialization",
schedule=None,
start_date=datetime(2020, 1, 1),
) as dag:
task1 = EmptyOperator(task_id="task1")
task2 = EmptyOperator(task_id="task2")
task1 >> Label("test label") >> task2
dag_dict = SerializedDAG.to_dict(dag)
SerializedDAG.validate_schema(dag_dict)
json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
self.validate_deserialized_dag(json_dag, dag)
serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
assert serialized_dag.edge_info == dag.edge_info
@pytest.mark.db_test
@pytest.mark.parametrize("mode", ["poke", "reschedule"])
def test_serialize_sensor(self, mode):
from airflow.sdk.bases.sensor import BaseSensorOperator
class DummySensor(BaseSensorOperator):
def poke(self, context: Context):
return False
op = DummySensor(task_id="dummy", mode=mode, poke_interval=23)
blob = SerializedBaseOperator.serialize_operator(op)
assert "_is_sensor" in blob
serialized_op = SerializedBaseOperator.deserialize_operator(blob)
assert serialized_op.reschedule == (mode == "reschedule")
assert ReadyToRescheduleDep in [type(d) for d in serialized_op.deps]
@pytest.mark.parametrize("mode", ["poke", "reschedule"])
def test_serialize_mapped_sensor_has_reschedule_dep(self, mode):
from airflow.sdk.bases.sensor import BaseSensorOperator
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
class DummySensor(BaseSensorOperator):
def poke(self, context: Context):
return False
op = DummySensor.partial(task_id="dummy", mode=mode).expand(poke_interval=[23])
blob = SerializedBaseOperator.serialize_mapped_operator(op)
assert "_is_sensor" in blob
assert "_is_mapped" in blob
serialized_op = SerializedBaseOperator.deserialize_operator(blob)
assert ReadyToRescheduleDep in [type(d) for d in serialized_op.deps]
@pytest.mark.parametrize(
("passed_success_callback", "expected_value"),
[
({"on_success_callback": lambda x: print("hi")}, True),
({}, False),
],
)
def test_dag_on_success_callback_roundtrip(self, passed_success_callback, expected_value):
"""
Test that when on_success_callback is passed to the DAG, has_on_success_callback is stored
in the serialized JSON blob and dag.has_on_success_callback is True after deserialization.
When the callback is not set, has_on_success_callback should not be stored in the serialized
blob and should therefore default to False on deserialization.
"""
dag = DAG(
dag_id="test_dag_on_success_callback_roundtrip",
schedule=None,
**passed_success_callback,
)
BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1))
serialized_dag = SerializedDAG.to_dict(dag)
if expected_value:
assert "has_on_success_callback" in serialized_dag["dag"]
else:
assert "has_on_success_callback" not in serialized_dag["dag"]
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
assert deserialized_dag.has_on_success_callback is expected_value
@pytest.mark.parametrize(
("passed_failure_callback", "expected_value"),
[
({"on_failure_callback": lambda x: print("hi")}, True),
({}, False),
],
)
def test_dag_on_failure_callback_roundtrip(self, passed_failure_callback, expected_value):
"""
Test that when on_failure_callback is passed to the DAG, has_on_failure_callback is stored
in the serialized JSON blob and dag.has_on_failure_callback is True after deserialization.
When the callback is not set, has_on_failure_callback should not be stored in the serialized
blob and should therefore default to False on deserialization.
"""
dag = DAG(
dag_id="test_dag_on_failure_callback_roundtrip",
schedule=None,
**passed_failure_callback,
)
BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1))
serialized_dag = SerializedDAG.to_dict(dag)
if expected_value:
assert "has_on_failure_callback" in serialized_dag["dag"]
else:
assert "has_on_failure_callback" not in serialized_dag["dag"]
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
assert deserialized_dag.has_on_failure_callback is expected_value
@pytest.mark.parametrize(
("dag_arg", "conf_arg", "expected"),
[
(True, "True", True),
(True, "False", True),
(False, "True", False),
(False, "False", False),
(None, "True", True),
(None, "False", False),
],
)
def test_dag_disable_bundle_versioning_roundtrip(self, dag_arg, conf_arg, expected):
"""
Test that disable_bundle_versioning round-trips through serialization. A value passed
directly to the DAG takes precedence over the ``[dag_processor] disable_bundle_versioning``
config option; when no value is passed, the config option determines the result.
"""
with conf_vars({("dag_processor", "disable_bundle_versioning"): conf_arg}):
kwargs = {}
if dag_arg is not None:
kwargs["disable_bundle_versioning"] = dag_arg
dag = DAG(
dag_id="test_dag_disable_bundle_versioning_roundtrip",
schedule=None,
**kwargs,
)
BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1))
serialized_dag = SerializedDAG.to_dict(dag)
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
assert deserialized_dag.disable_bundle_versioning is expected
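# A reader aid for the expected values used in many tests below (not exercised directly
# here): BaseSerialization wraps non-primitive values, including plain dicts, in a small
# envelope of the form {"__type": <type tag>, "__var": <payload>}. For example, a
# five-minute retry_delay serializes to {"__type": "timedelta", "__var": 300.0} and nested
# dicts become {"__type": "dict", "__var": {...}}, which is why many assertions compare
# against the "__var" payload.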
@pytest.mark.parametrize(
("object_to_serialized", "expected_output"),
[
(
["task_1", "task_5", "task_2", "task_4"],
["task_1", "task_5", "task_2", "task_4"],
),
(
{"task_1", "task_5", "task_2", "task_4"},
["task_1", "task_2", "task_4", "task_5"],
),
(
("task_1", "task_5", "task_2", "task_4"),
["task_1", "task_5", "task_2", "task_4"],
),
(
{
"staging_schema": [
{"key:": "foo", "value": "bar"},
{"key:": "this", "value": "that"},
"test_conf",
]
},
{
"staging_schema": [
{"__type": "dict", "__var": {"key:": "foo", "value": "bar"}},
{
"__type": "dict",
"__var": {"key:": "this", "value": "that"},
},
"test_conf",
]
},
),
(
{"task3": "test3", "task2": "test2", "task1": "test1"},
{"task1": "test1", "task2": "test2", "task3": "test3"},
),
(
("task_1", "task_5", "task_2", 3, ["x", "y"]),
["task_1", "task_5", "task_2", 3, ["x", "y"]],
),
],
)
def test_serialized_objects_are_sorted(self, object_to_serialized, expected_output):
"""Test Serialized Sets are sorted while list and tuple preserve order"""
serialized_obj = SerializedDAG.serialize(object_to_serialized)
if isinstance(serialized_obj, dict) and "__type" in serialized_obj:
serialized_obj = serialized_obj["__var"]
assert serialized_obj == expected_output
def test_params_upgrade(self):
"""When pre-2.2.0 param (i.e. primitive) is deserialized we convert to Param"""
serialized = {
"__version": 3,
"dag": {
"dag_id": "simple_dag",
"fileloc": "/path/to/file.py",
"tasks": [],
"timezone": "UTC",
"params": {"none": None, "str": "str", "dict": {"a": "b"}},
},
}
dag = SerializedDAG.from_dict(serialized)
assert dag.params["none"] is None
# After decoupling, server-side deserialization uses SerializedParam
assert isinstance(dag.params.get_param("none"), SerializedParam)
assert dag.params["str"] == "str"
def test_params_serialization_from_dict_upgrade(self):
"""
In <=2.9.2 params were serialized as a JSON object instead of a list of key-value pairs.
This test asserts that the params are still deserialized properly.
"""
serialized = {
"__version": 3,
"dag": {
"dag_id": "simple_dag",
"fileloc": "/path/to/file.py",
"tasks": [],
"timezone": "UTC",
"params": {
"my_param": {
"__class": "airflow.models.param.Param",
"default": "str",
}
},
},
}
dag = SerializedDAG.from_dict(serialized)
param = dag.params.get_param("my_param")
# After decoupling, server-side deserialization uses SerializedParam
assert isinstance(param, SerializedParam)
assert param.value == "str"
def test_params_serialize_default_2_2_0(self):
"""
In 2.0.0, param ``default`` was assumed to be a json-serializable object and was not run through
the standard serializer function. In 2.2.2 we serialize param ``default``. We keep this
test only to ensure that params stored in 2.2.0 can still be parsed correctly.
"""
serialized = {
"__version": 3,
"dag": {
"dag_id": "simple_dag",
"fileloc": "/path/to/file.py",
"tasks": [],
"timezone": "UTC",
"params": [["str", {"__class": "airflow.models.param.Param", "default": "str"}]],
},
}
SerializedDAG.validate_schema(serialized)
dag = SerializedDAG.from_dict(serialized)
# After decoupling, server-side deserialization uses SerializedParam
assert isinstance(dag.params.get_param("str"), SerializedParam)
assert dag.params["str"] == "str"
def test_params_serialize_default(self):
serialized = {
"__version": 3,
"dag": {
"dag_id": "simple_dag",
"fileloc": "/path/to/file.py",
"tasks": [],
"timezone": "UTC",
"params": [
[
"my_param",
{
"default": "a string value",
"description": "hello",
"schema": {"__var": {"type": "string"}, "__type": "dict"},
"__class": "airflow.models.param.Param",
},
]
],
},
}
SerializedDAG.validate_schema(serialized)
dag = SerializedDAG.from_dict(serialized)
assert dag.params["my_param"] == "a string value"
param = dag.params.get_param("my_param")
# After decoupling, server-side deserialization uses SerializedParam
assert isinstance(param, SerializedParam)
assert param.description == "hello"
assert param.schema == {"type": "string"}
@pytest.mark.db_test
def test_not_templateable_fields_in_serialized_dag(self):
"""
Test that an Airflow exception is raised when non-templateable fields are declared as template fields.
"""
class TestOperator(BaseOperator):
template_fields = (
"execution_timeout", # not templateable
"run_as_user", # templateable
)
def execute(self, context: Context):
pass
dag = DAG(dag_id="test_dag", schedule=None, start_date=datetime(2023, 11, 9))
with dag:
task = TestOperator(
task_id="test_task",
run_as_user="{{ test_run_as_user }}",
execution_timeout=timedelta(seconds=10),
)
task.render_template_fields(context={"test_run_as_user": "foo"})
assert task.run_as_user == "foo"
with pytest.raises(
AirflowException,
match=re.escape(
dedent(
"""Failed to serialize DAG 'test_dag': Cannot template BaseOperator field:
'execution_timeout' op.__class__.__name__='TestOperator' op.template_fields=('execution_timeout', 'run_as_user')"""
)
),
):
SerializedDAG.to_dict(dag)
@pytest.mark.db_test
def test_start_trigger_args_in_serialized_dag(self):
"""
Test that when we provide start_trigger_args, the DAG can be correctly serialized.
"""
class TestOperator(BaseOperator):
start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger",
trigger_kwargs={"delta": timedelta(seconds=1)},
next_method="execute_complete",
next_kwargs=None,
timeout=None,
)
start_from_trigger = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.start_trigger_args.trigger_kwargs = {"delta": timedelta(seconds=2)}
self.start_from_trigger = True
def execute_complete(self):
pass
class Test2Operator(BaseOperator):
start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.triggers.testing.SuccessTrigger",
trigger_kwargs={},
next_method="execute_complete",
next_kwargs=None,
timeout=None,
)
start_from_trigger = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def execute_complete(self):
pass
dag = DAG(dag_id="test_dag", schedule=None, start_date=datetime(2023, 11, 9))
with dag:
TestOperator(task_id="test_task_1")
Test2Operator(task_id="test_task_2")
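# The assertions below check that the instance-level changes made in __init__ (the trigger
# delta bumped to 2 seconds and start_from_trigger flipped to True) are what end up in the
# serialized blob, not the class-level defaults.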
serialized_obj = SerializedDAG.to_dict(dag)
tasks = serialized_obj["dag"]["tasks"]
assert tasks[0]["__var"]["start_trigger_args"] == {
"__type": "START_TRIGGER_ARGS",
"trigger_cls": "airflow.providers.standard.triggers.temporal.TimeDeltaTrigger",
# "trigger_kwargs": {"__type": "dict", "__var": {"delta": {"__type": "timedelta", "__var": 2.0}}},
"trigger_kwargs": {
"__type": "dict",
"__var": {"delta": {"__type": "timedelta", "__var": 2.0}},
},
"next_method": "execute_complete",
"next_kwargs": None,
"timeout": None,
}
assert tasks[0]["__var"]["start_from_trigger"] is True
assert tasks[1]["__var"]["start_trigger_args"] == {
"__type": "START_TRIGGER_ARGS",
"trigger_cls": "airflow.triggers.testing.SuccessTrigger",
"trigger_kwargs": {"__type": "dict", "__var": {}},
"next_method": "execute_complete",
"next_kwargs": None,
"timeout": None,
}
assert tasks[1]["__var"]["start_from_trigger"] is True
def test_kubernetes_optional():
"""Test that serialization module loads without kubernetes, but deserialization of PODs requires it"""
def mock__import__(name, globals_=None, locals_=None, fromlist=(), level=0):
if level == 0 and name.partition(".")[0] == "kubernetes":
raise ImportError("No module named 'kubernetes'")
return importlib.__import__(name, globals=globals_, locals=locals_, fromlist=fromlist, level=level)
with mock.patch("builtins.__import__", side_effect=mock__import__) as import_mock:
# load the module from scratch; this does not replace any already imported
# airflow.serialization.serialized_objects module in sys.modules
spec = importlib.util.find_spec("airflow.serialization.serialized_objects")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# if we got this far, the module did not try to load kubernetes, but
# did it try to access airflow.providers.cncf.kubernetes.*?
imported_airflow = {
c.args[0].split(".", 2)[1] for c in import_mock.call_args_list if c.args[0].startswith("airflow.")
}
assert "kubernetes" not in imported_airflow
# pod loading is not supported when kubernetes is not available
pod_override = {
"__type": "k8s.V1Pod",
"__var": PodGenerator.serialize_pod(executor_config_pod),
}
# we should error if attempting to deserialize POD without kubernetes installed
with pytest.raises(RuntimeError, match="Cannot deserialize POD objects without kubernetes"):
module.BaseSerialization.from_dict(pod_override)
# basic serialization should succeed
module.SerializedDAG.to_dict(make_simple_dag())
def test_operator_expand_serde():
literal = [1, 2, {"a": "b"}]
real_op = BashOperator.partial(task_id="a", executor_config={"dict": {"sub": "value"}}).expand(
bash_command=literal
)
serialized = BaseSerialization.serialize(real_op)
assert serialized["__var"] == {
"_is_mapped": True,
"_task_module": "airflow.providers.standard.operators.bash",
"task_type": "BashOperator",
"expand_input": {
"type": "dict-of-lists",
"value": {
"__type": "dict",
"__var": {"bash_command": [1, 2, {"__type": "dict", "__var": {"a": "b"}}]},
},
},
"partial_kwargs": {
"executor_config": {
"__type": "dict",
"__var": {"dict": {"__type": "dict", "__var": {"sub": "value"}}},
},
"retry_delay": {"__type": "timedelta", "__var": 300.0},
},
"task_id": "a",
"template_fields": ["bash_command", "env", "cwd"],
"template_ext": [".sh", ".bash"],
"template_fields_renderers": {"bash_command": "bash", "env": "json"},
"ui_color": "#f0ede4",
"_disallow_kwargs_override": False,
"_expand_input_attr": "expand_input",
}
op = BaseSerialization.deserialize(serialized)
assert isinstance(op, MappedOperator)
# operator_class now stores only minimal type information for memory efficiency
assert op.operator_class == {
"task_type": "BashOperator",
"_operator_name": "BashOperator",
}
assert op.expand_input.value["bash_command"] == literal
assert op.partial_kwargs["executor_config"] == {"dict": {"sub": "value"}}
def test_operator_expand_xcomarg_serde():
from airflow.models.xcom_arg import SchedulerPlainXComArg
from airflow.sdk.definitions.xcom_arg import XComArg
from airflow.serialization.serialized_objects import _XComRef
with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag:
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id="task_2").expand(arg2=XComArg(task1))
serialized = BaseSerialization.serialize(mapped)
assert serialized["__var"] == {
"_is_mapped": True,
"_task_module": "tests_common.test_utils.mock_operators",
"task_type": "MockOperator",
"expand_input": {
"type": "dict-of-lists",
"value": {
"__type": "dict",
"__var": {
"arg2": {
"__type": "xcomref",
"__var": {"task_id": "op1", "key": "return_value"},
}
},
},
},
"partial_kwargs": {
"retry_delay": {"__type": "timedelta", "__var": 300.0},
},
"task_id": "task_2",
"template_fields": ["arg1", "arg2"],
"_disallow_kwargs_override": False,
"_expand_input_attr": "expand_input",
}
op = BaseSerialization.deserialize(serialized)
# The XComArg can't be deserialized before the DAG is.
xcom_ref = op.expand_input.value["arg2"]
assert xcom_ref == _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})
serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
xcom_arg = serialized_dag.task_dict["task_2"].expand_input.value["arg2"]
assert isinstance(xcom_arg, SchedulerPlainXComArg)
assert xcom_arg.operator is serialized_dag.task_dict["op1"]
@pytest.mark.parametrize("strict", [True, False])
def test_operator_expand_kwargs_literal_serde(strict):
from airflow.sdk.definitions.xcom_arg import XComArg
from airflow.serialization.serialized_objects import DEFAULT_OPERATOR_DEPS, _XComRef
with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag:
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id="task_2").expand_kwargs(
[{"a": "x"}, {"a": XComArg(task1)}],
strict=strict,
)
serialized = BaseSerialization.serialize(mapped)
assert serialized["__var"] == {
"_is_mapped": True,
"_task_module": "tests_common.test_utils.mock_operators",
"task_type": "MockOperator",
"expand_input": {
"type": "list-of-dicts",
"value": [
{"__type": "dict", "__var": {"a": "x"}},
{
"__type": "dict",
"__var": {
"a": {
"__type": "xcomref",
"__var": {"task_id": "op1", "key": "return_value"},
}
},
},
],
},
"partial_kwargs": {
"retry_delay": {"__type": "timedelta", "__var": 300.0},
},
"task_id": "task_2",
"template_fields": ["arg1", "arg2"],
"_disallow_kwargs_override": strict,
"_expand_input_attr": "expand_input",
}
op = BaseSerialization.deserialize(serialized)
assert op.deps == DEFAULT_OPERATOR_DEPS
assert op._disallow_kwargs_override == strict
# The XComArg can't be deserialized before the DAG is.
expand_value = op.expand_input.value
assert expand_value == [
{"a": "x"},
{"a": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})},
]
serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
resolved_expand_value = serialized_dag.task_dict["task_2"].expand_input.value
assert resolved_expand_value == [
{"a": "x"},
{"a": _XComRef({"task_id": "op1", "key": "return_value"})},
]
@pytest.mark.parametrize("strict", [True, False])
def test_operator_expand_kwargs_xcomarg_serde(strict):
from airflow.models.xcom_arg import SchedulerPlainXComArg
from airflow.sdk.definitions.xcom_arg import XComArg
from airflow.serialization.serialized_objects import _XComRef
with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag:
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id="task_2").expand_kwargs(XComArg(task1), strict=strict)
serialized = SerializedBaseOperator.serialize(mapped)
assert serialized["__var"] == {
"_is_mapped": True,
"_task_module": "tests_common.test_utils.mock_operators",
"task_type": "MockOperator",
"expand_input": {
"type": "list-of-dicts",
"value": {
"__type": "xcomref",
"__var": {"task_id": "op1", "key": "return_value"},
},
},
"partial_kwargs": {
"retry_delay": {"__type": "timedelta", "__var": 300.0},
},
"task_id": "task_2",
"template_fields": ["arg1", "arg2"],
"_disallow_kwargs_override": strict,
"_expand_input_attr": "expand_input",
}
op = BaseSerialization.deserialize(serialized)
assert op._disallow_kwargs_override == strict
# The XComArg can't be deserialized before the DAG is.
xcom_ref = op.expand_input.value
assert xcom_ref == _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})
serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
xcom_arg = serialized_dag.task_dict["task_2"].expand_input.value
assert isinstance(xcom_arg, SchedulerPlainXComArg)
assert xcom_arg.operator is serialized_dag.task_dict["op1"]
def test_task_resources_serde():
"""
Test task resources serialization/deserialization.
"""
from airflow.providers.standard.operators.empty import EmptyOperator
logical_date = datetime(2020, 1, 1)
task_id = "task1"
with DAG("test_task_resources", schedule=None, start_date=logical_date) as _:
task = EmptyOperator(task_id=task_id, resources={"cpus": 0.1, "ram": 2048})
serialized = BaseSerialization.serialize(task)
assert serialized["__var"]["resources"] == {
"cpus": {"name": "CPU", "qty": 0.1, "units_str": "core(s)"},
"disk": {"name": "Disk", "qty": 512, "units_str": "MB"},
"gpus": {"name": "GPU", "qty": 0, "units_str": "gpu(s)"},
"ram": {"name": "RAM", "qty": 2048, "units_str": "MB"},
}
@pytest.mark.parametrize("execution_timeout", [None, timedelta(hours=1)])
def test_task_execution_timeout_serde(execution_timeout):
"""
Test task execution_timeout serialization/deserialization.
"""
from airflow.providers.standard.operators.empty import EmptyOperator
with DAG("test_task_execution_timeout", schedule=None, start_date=datetime(2020, 1, 1)) as _:
task = EmptyOperator(task_id="task1", execution_timeout=execution_timeout)
serialized = BaseSerialization.serialize(task)
if execution_timeout:
assert "execution_timeout" in serialized["__var"]
deserialized = BaseSerialization.deserialize(serialized)
assert deserialized.execution_timeout == task.execution_timeout
def test_taskflow_expand_serde():
from airflow.models.xcom_arg import XComArg
from airflow.sdk import task
from airflow.serialization.serialized_objects import _ExpandInputRef, _XComRef
with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag:
op1 = BaseOperator(task_id="op1")
@task(retry_delay=30)
def x(arg1, arg2, arg3):
print(arg1, arg2, arg3)
print("**", type(x), type(x.partial), type(x.expand))
x.partial(arg1=[1, 2, {"a": "b"}]).expand(arg2={"a": 1, "b": 2}, arg3=XComArg(op1))
original = dag.get_task("x")
serialized = BaseSerialization.serialize(original)
assert serialized["__var"] == {
"_is_mapped": True,
"_task_module": "airflow.providers.standard.decorators.python",
"task_type": "_PythonDecoratedOperator",
"_operator_name": "@task",
"partial_kwargs": {
"op_args": [],
"op_kwargs": {
"__type": "dict",
"__var": {"arg1": [1, 2, {"__type": "dict", "__var": {"a": "b"}}]},
},
"retry_delay": {"__type": "timedelta", "__var": 30.0},
},
"op_kwargs_expand_input": {
"type": "dict-of-lists",
"value": {
"__type": "dict",
"__var": {
"arg2": {"__type": "dict", "__var": {"a": 1, "b": 2}},
"arg3": {
"__type": "xcomref",
"__var": {"task_id": "op1", "key": "return_value"},
},
},
},
},
"ui_color": "#ffefeb",
"task_id": "x",
"template_fields": ["templates_dict", "op_args", "op_kwargs"],
"template_fields_renderers": {
"templates_dict": "json",
"op_args": "py",
"op_kwargs": "py",
},
"_disallow_kwargs_override": False,
"_expand_input_attr": "op_kwargs_expand_input",
"python_callable_name": qualname(x),
}
deserialized = BaseSerialization.deserialize(serialized)
assert isinstance(deserialized, MappedOperator)
assert deserialized.upstream_task_ids == set()
assert deserialized.downstream_task_ids == set()
assert deserialized.op_kwargs_expand_input == _ExpandInputRef(
key="dict-of-lists",
value={
"arg2": {"a": 1, "b": 2},
"arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
},
)
assert deserialized.partial_kwargs == {
"op_args": [],
"op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
"retry_delay": timedelta(seconds=30),
}
# this dag is not pickleable in this context, so we have to simply
# set it to None
deserialized.dag = None
# Ensure the serialized operator can also be correctly pickled, to verify
# correct interaction between DAG pickling and serialization. This is done
# here so we don't need to duplicate tests between pickled and non-pickled
# DAGs everywhere else.
pickled = pickle.loads(pickle.dumps(deserialized))
assert pickled.op_kwargs_expand_input == _ExpandInputRef(
key="dict-of-lists",
value={
"arg2": {"a": 1, "b": 2},
"arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
},
)
assert pickled.partial_kwargs == {
"op_args": [],
"op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
"retry_delay": timedelta(seconds=30),
}
@pytest.mark.parametrize("strict", [True, False])
def test_taskflow_expand_kwargs_serde(strict):
from airflow.models.xcom_arg import XComArg
from airflow.sdk import task
from airflow.serialization.serialized_objects import _ExpandInputRef, _XComRef
with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag:
op1 = BaseOperator(task_id="op1")
@task(retry_delay=30)
def x(arg1, arg2, arg3):
print(arg1, arg2, arg3)
x.partial(arg1=[1, 2, {"a": "b"}]).expand_kwargs(XComArg(op1), strict=strict)
original = dag.get_task("x")
serialized = BaseSerialization.serialize(original)
assert serialized["__var"] == {
"_is_mapped": True,
"_task_module": "airflow.providers.standard.decorators.python",
"task_type": "_PythonDecoratedOperator",
"_operator_name": "@task",
"python_callable_name": qualname(x),
"partial_kwargs": {
"op_args": [],
"op_kwargs": {
"__type": "dict",
"__var": {"arg1": [1, 2, {"__type": "dict", "__var": {"a": "b"}}]},
},
"retry_delay": {"__type": "timedelta", "__var": 30.0},
},
"op_kwargs_expand_input": {
"type": "list-of-dicts",
"value": {
"__type": "xcomref",
"__var": {"task_id": "op1", "key": "return_value"},
},
},
"ui_color": "#ffefeb",
"task_id": "x",
"template_fields": ["templates_dict", "op_args", "op_kwargs"],
"template_fields_renderers": {
"templates_dict": "json",
"op_args": "py",
"op_kwargs": "py",
},
"_disallow_kwargs_override": strict,
"_expand_input_attr": "op_kwargs_expand_input",
}
deserialized = BaseSerialization.deserialize(serialized)
assert isinstance(deserialized, MappedOperator)
assert deserialized._disallow_kwargs_override == strict
assert deserialized.upstream_task_ids == set()
assert deserialized.downstream_task_ids == set()
assert deserialized.op_kwargs_expand_input == _ExpandInputRef(
key="list-of-dicts",
value=_XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
)
assert deserialized.partial_kwargs == {
"op_args": [],
"op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
"retry_delay": timedelta(seconds=30),
}
# this dag is not pickleable in this context, so we have to simply
# set it to None
deserialized.dag = None
# Ensure the serialized operator can also be correctly pickled, to verify
# correct interaction between DAG pickling and serialization. This is done
# here so we don't need to duplicate tests between pickled and non-pickled
# DAGs everywhere else.
pickled = pickle.loads(pickle.dumps(deserialized))
assert pickled.op_kwargs_expand_input == _ExpandInputRef(
"list-of-dicts",
_XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
)
assert pickled.partial_kwargs == {
"op_args": [],
"op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
"retry_delay": timedelta(seconds=30),
}
def test_mapped_task_group_serde():
from airflow.models.expandinput import SchedulerDictOfListsExpandInput
from airflow.sdk.definitions.decorators.task_group import task_group
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag:
@task_group
def tg(a: str) -> None:
BaseOperator(task_id="op1")
with pytest.raises(NotImplementedError) as ctx:
BashOperator.partial(task_id="op2").expand(bash_command=["ls", a])
assert str(ctx.value) == "operator expansion in an expanded task group is not yet supported"
tg.expand(a=[".", ".."])
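# op2 never makes it into the task group: expanding an operator inside an already expanded
# task group raises NotImplementedError above, which is why the expected children below
# list only tg.op1 (the tg.op2 entry is left commented out).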
ser_dag = SerializedBaseOperator.serialize(dag)
assert ser_dag[Encoding.VAR]["task_group"]["children"]["tg"] == (
"taskgroup",
{
"_group_id": "tg",
"children": {
"tg.op1": ("operator", "tg.op1"),
# "tg.op2": ("operator", "tg.op2"),
},
"downstream_group_ids": [],
"downstream_task_ids": [],
"expand_input": {
"type": "dict-of-lists",
"value": {"__type": "dict", "__var": {"a": [".", ".."]}},
},
"group_display_name": "",
"is_mapped": True,
"prefix_group_id": True,
"tooltip": "",
"ui_color": "CornflowerBlue",
"ui_fgcolor": "#000",
"upstream_group_ids": [],
"upstream_task_ids": [],
},
)
serde_dag = SerializedDAG.deserialize_dag(ser_dag[Encoding.VAR])
serde_tg = serde_dag.task_group.children["tg"]
assert isinstance(serde_tg, SerializedTaskGroup)
assert serde_tg._expand_input == SchedulerDictOfListsExpandInput({"a": [".", ".."]})
@pytest.mark.db_test
def test_mapped_task_with_operator_extra_links_property():
class _DummyOperator(BaseOperator):
def __init__(self, inputs, **kwargs):
super().__init__(**kwargs)
self.inputs = inputs
@property
def operator_extra_links(self):
return (AirflowLink2(),)
with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag:
_DummyOperator.partial(task_id="task").expand(inputs=[1, 2, 3])
serialized_dag = SerializedBaseOperator.serialize(dag)
assert serialized_dag[Encoding.VAR]["tasks"][0]["__var"] == {
"task_id": "task",
"expand_input": {
"type": "dict-of-lists",
"value": {"__type": "dict", "__var": {"inputs": [1, 2, 3]}},
},
"partial_kwargs": {
"retry_delay": {"__type": "timedelta", "__var": 300.0},
},
"_disallow_kwargs_override": False,
"_expand_input_attr": "expand_input",
"_operator_extra_links": {"airflow": "_link_AirflowLink2"},
"template_fields": [],
"task_type": "_DummyOperator",
"_task_module": "unit.serialization.test_dag_serialization",
"_is_mapped": True,
}
deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR])
# operator defined links have to be instances of XComOperatorLink
assert deserialized_dag.task_dict["task"].operator_extra_links == [
XComOperatorLink(name="airflow", xcom_key="_link_AirflowLink2")
]
mapped_task = deserialized_dag.task_dict["task"]
assert mapped_task.operator_extra_link_dict == {
"airflow": XComOperatorLink(name="airflow", xcom_key="_link_AirflowLink2")
}
assert mapped_task.global_operator_extra_link_dict == {"airflow": AirflowLink(), "github": GithubLink()}
assert mapped_task.extra_links == sorted({"airflow", "github"})
def empty_function(*args, **kwargs):
"""Empty function for testing."""
def test_python_callable_in_partial_kwargs():
from airflow.providers.standard.operators.python import PythonOperator
operator = PythonOperator.partial(
task_id="task",
python_callable=empty_function,
).expand(op_kwargs=[{"x": 1}])
serialized = SerializedBaseOperator.serialize_mapped_operator(operator)
assert "python_callable" not in serialized["partial_kwargs"]
assert serialized["partial_kwargs"]["python_callable_name"] == qualname(empty_function)
deserialized = SerializedBaseOperator.deserialize_operator(serialized)
assert "python_callable" not in deserialized.partial_kwargs
assert deserialized.partial_kwargs["python_callable_name"] == qualname(empty_function)
def test_handle_v1_serdag():
v1 = {
"__version": 1,
"dag": {
"default_args": {
"__type": "dict",
"__var": {
"depends_on_past": False,
"retries": 1,
"retry_delay": {"__type": "timedelta", "__var": 240.0},
"max_retry_delay": {"__type": "timedelta", "__var": 600.0},
"sla": {"__type": "timedelta", "__var": 100.0},
},
},
"start_date": 1564617600.0,
"_task_group": {
"_group_id": None,
"prefix_group_id": True,
"children": {
"bash_task": ("operator", "bash_task"),
"custom_task": ("operator", "custom_task"),
},
"tooltip": "",
"ui_color": "CornflowerBlue",
"ui_fgcolor": "#000",
"upstream_group_ids": [],
"downstream_group_ids": [],
"upstream_task_ids": [],
"downstream_task_ids": [],
},
"is_paused_upon_creation": False,
"max_active_runs": 16,
"max_active_tasks": 16,
"max_consecutive_failed_dag_runs": 0,
"_dag_id": "simple_dag",
"deadline": None,
"doc_md": "### DAG Tutorial Documentation",
"fileloc": None,
"_processor_dags_folder": (
AIRFLOW_REPO_ROOT_PATH / "airflow-core" / "tests" / "unit" / "dags"
).as_posix(),
"tasks": [
{
"__type": "operator",
"__var": {
"task_id": "bash_task",
"retries": 1,
"retry_delay": 240.0,
"max_retry_delay": 600.0,
"sla": 100.0,
"downstream_task_ids": [],
"ui_color": "#f0ede4",
"ui_fgcolor": "#000",
"template_ext": [".sh", ".bash"],
"template_fields": ["bash_command", "env", "cwd"],
"template_fields_renderers": {"bash_command": "bash", "env": "json"},
"bash_command": "echo {{ task.task_id }}",
"_task_type": "BashOperator",
# Slight difference from v2-10-stable here: we manually changed this path
"_task_module": "airflow.providers.standard.operators.bash",
"owner": "airflow1",
"pool": "pool1",
"task_display_name": "my_bash_task",
"is_setup": False,
"is_teardown": False,
"on_failure_fail_dagrun": False,
"executor_config": {
"__type": "dict",
"__var": {
"pod_override": {
"__type": "k8s.V1Pod",
"__var": PodGenerator.serialize_pod(executor_config_pod),
}
},
},
"doc_md": "### Task Tutorial Documentation",
"_log_config_logger_name": "airflow.task.operators",
"_needs_expansion": False,
"weight_rule": "downstream",
"start_trigger_args": None,
"start_from_trigger": False,
"inlets": [
{
"__type": "dataset",
"__var": {
"extra": {},
"uri": "asset-1",
},
},
{
"__type": "dataset_alias",
"__var": {"name": "alias-name"},
},
],
"outlets": [
{
"__type": "dataset",
"__var": {
"extra": {},
"uri": "asset-2",
},
},
],
},
},
{
"__type": "operator",
"__var": {
"task_id": "custom_task",
"retries": 1,
"retry_delay": 240.0,
"max_retry_delay": 600.0,
"sla": 100.0,
"downstream_task_ids": [],
"_operator_extra_links": [{"tests.test_utils.mock_operators.CustomOpLink": {}}],
"ui_color": "#fff",
"ui_fgcolor": "#000",
"template_ext": [],
"template_fields": ["bash_command"],
"template_fields_renderers": {},
"_task_type": "CustomOperator",
"_operator_name": "@custom",
# Slight difference from v2-10-stable here: we manually changed this path
"_task_module": "tests_common.test_utils.mock_operators",
"pool": "default_pool",
"is_setup": False,
"is_teardown": False,
"on_failure_fail_dagrun": False,
"_log_config_logger_name": "airflow.task.operators",
"_needs_expansion": False,
"weight_rule": "downstream",
"start_trigger_args": None,
"start_from_trigger": False,
},
},
],
"schedule_interval": {"__type": "timedelta", "__var": 86400.0},
"timezone": "UTC",
"_access_control": {
"__type": "dict",
"__var": {
"test_role": {
"__type": "dict",
"__var": {
"DAGs": {
"__type": "set",
"__var": [permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT],
}
},
}
},
},
"edge_info": {},
"dag_dependencies": [
# dataset as schedule (source)
{
"source": "dataset",
"target": "dag1",
"dependency_type": "dataset",
"dependency_id": "dataset_uri_1",
},
# dataset alias (resolved) as schedule (source)
{
"source": "dataset",
"target": "dataset-alias:alias_name_1",
"dependency_type": "dataset",
"dependency_id": "dataset_uri_2",
},
{
"source": "dataset:alias_name_1",
"target": "dag2",
"dependency_type": "dataset-alias",
"dependency_id": "alias_name_1",
},
# dataset alias (not resolved) as schedule (source)
{
"source": "dataset-alias",
"target": "dag2",
"dependency_type": "dataset-alias",
"dependency_id": "alias_name_2",
},
# dataset as outlets (target)
{
"source": "dag10",
"target": "dataset",
"dependency_type": "dataset",
"dependency_id": "dataset_uri_10",
},
# dataset alias (resolved) as outlets (target)
{
"source": "dag20",
"target": "dataset-alias:alias_name_10",
"dependency_type": "dataset",
"dependency_id": "dataset_uri_20",
},
{
"source": "dataset:dataset_uri_20",
"target": "dataset-alias",
"dependency_type": "dataset-alias",
"dependency_id": "alias_name_10",
},
# dataset alias (not resolved) as outlets (target)
{
"source": "dag2",
"target": "dataset-alias",
"dependency_type": "dataset-alias",
"dependency_id": "alias_name_2",
},
],
"params": [],
},
}
expected_dag_dependencies = [
# asset as schedule (source)
{
"dependency_id": "dataset_uri_1",
"dependency_type": "asset",
"label": "dataset_uri_1",
"source": "asset",
"target": "dag1",
},
# asset alias (resolved) as schedule (source)
{
"dependency_id": "dataset_uri_2",
"dependency_type": "asset",
"label": "dataset_uri_2",
"source": "asset",
"target": "asset-alias:alias_name_1",
},
{
"dependency_id": "alias_name_1",
"dependency_type": "asset-alias",
"label": "alias_name_1",
"source": "asset:alias_name_1",
"target": "dag2",
},
# asset alias (not resolved) as schedule (source)
{
"dependency_id": "alias_name_2",
"dependency_type": "asset-alias",
"label": "alias_name_2",
"source": "asset-alias",
"target": "dag2",
},
# asset as outlets (target)
{
"dependency_id": "dataset_uri_10",
"dependency_type": "asset",
"label": "dataset_uri_10",
"source": "dag10",
"target": "asset",
},
# asset alias (resolved) as outlets (target)
{
"dependency_id": "dataset_uri_20",
"dependency_type": "asset",
"label": "dataset_uri_20",
"source": "dag20",
"target": "asset-alias:alias_name_10",
},
{
"dependency_id": "alias_name_10",
"dependency_type": "asset-alias",
"label": "alias_name_10",
"source": "asset:dataset_uri_20",
"target": "asset-alias",
},
# asset alias (not resolved) as outlets (target)
{
"dependency_id": "alias_name_2",
"dependency_type": "asset-alias",
"label": "alias_name_2",
"source": "dag2",
"target": "asset-alias",
},
]
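# The conversion helpers appear to upgrade the payload in place: the same ``v1`` dict is
# passed through conversion_v1_to_v2 and then conversion_v2_to_v3 before being handed to
# from_dict, so by that point it is in the current (v3) format.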
SerializedDAG.conversion_v1_to_v2(v1)
SerializedDAG.conversion_v2_to_v3(v1)
dag = SerializedDAG.from_dict(v1)
expected_sdag = copy.deepcopy(serialized_simple_dag_ground_truth)
expected = SerializedDAG.from_dict(expected_sdag)
fields_to_verify = set(vars(expected).keys()) - {
"task_group", # Tested separately
"dag_dependencies", # Tested separately
"last_loaded", # Dynamically set to utcnow
}
for f in fields_to_verify:
dag_value = getattr(dag, f)
expected_value = getattr(expected, f)
assert dag_value == expected_value, (
f"V2 DAG field '{f}' differs from V3: V2={dag_value!r} != V3={expected_value!r}"
)
for f in set(vars(expected.task_group).keys()) - {"dag"}:
dag_tg_value = getattr(dag.task_group, f)
expected_tg_value = getattr(expected.task_group, f)
assert dag_tg_value == expected_tg_value, (
f"V2 task_group field '{f}' differs: V2={dag_tg_value!r} != V3={expected_tg_value!r}"
)
assert getattr(dag, "dag_dependencies") == expected_dag_dependencies
def test_handle_v2_serdag():
"""Test that v2 serialized DAGs can be deserialized properly."""
v2 = {
"__version": 2,
"dag": {
"default_args": {
"__type": "dict",
"__var": {
"depends_on_past": False,
"retries": 1,
"retry_delay": {"__type": "timedelta", "__var": 240.0},
"max_retry_delay": {"__type": "timedelta", "__var": 600.0},
},
},
"start_date": 1564617600.0,
"timetable": {
"__type": "airflow.timetables.interval.DeltaDataIntervalTimetable",
"__var": {
"delta": 86400.0,
},
},
"task_group": {
"_group_id": None,
"group_display_name": "",
"prefix_group_id": True,
"children": {
"bash_task": ("operator", "bash_task"),
"custom_task": ("operator", "custom_task"),
},
"tooltip": "",
"ui_color": "CornflowerBlue",
"ui_fgcolor": "#000",
"upstream_group_ids": [],
"downstream_group_ids": [],
"upstream_task_ids": [],
"downstream_task_ids": [],
},
"is_paused_upon_creation": False,
"dag_id": "simple_dag",
"catchup": False,
"disable_bundle_versioning": False,
"doc_md": "### DAG Tutorial Documentation",
"fileloc": None,
"_processor_dags_folder": (
AIRFLOW_REPO_ROOT_PATH / "airflow-core" / "tests" / "unit" / "dags"
).as_posix(),
"tasks": [
{
"__type": "operator",
"__var": {
"task_id": "bash_task",
"retries": 1,
"retry_delay": 240.0,
"max_retry_delay": 600.0,
"downstream_task_ids": [],
"ui_color": "#f0ede4",
"ui_fgcolor": "#000",
"template_ext": [".sh", ".bash"],
"template_fields": ["bash_command", "env", "cwd"],
"template_fields_renderers": {
"bash_command": "bash",
"env": "json",
},
"bash_command": "echo {{ task.task_id }}",
"task_type": "BashOperator",
"_task_module": "airflow.providers.standard.operators.bash",
"owner": "airflow1",
"pool": "pool1",
"is_setup": False,
"is_teardown": False,
"on_failure_fail_dagrun": False,
"executor_config": {
"__type": "dict",
"__var": {
"pod_override": {
"__type": "k8s.V1Pod",
"__var": PodGenerator.serialize_pod(executor_config_pod),
}
},
},
"doc_md": "### Task Tutorial Documentation",
"_needs_expansion": False,
"weight_rule": "downstream",
"start_trigger_args": None,
"start_from_trigger": False,
"inlets": [
{
"__type": "asset",
"__var": {
"extra": {},
"group": "asset",
"name": "asset-1",
"uri": "asset-1",
},
},
{
"__type": "asset_alias",
"__var": {"group": "asset", "name": "alias-name"},
},
],
"outlets": [
{
"__type": "asset",
"__var": {
"extra": {},
"group": "asset",
"name": "asset-2",
"uri": "asset-2",
},
},
],
},
},
{
"__type": "operator",
"__var": {
"task_id": "custom_task",
"retries": 1,
"retry_delay": 240.0,
"max_retry_delay": 600.0,
"downstream_task_ids": [],
"_operator_extra_links": {"Google Custom": "_link_CustomOpLink"},
"ui_color": "#fff",
"ui_fgcolor": "#000",
"template_ext": [],
"template_fields": ["bash_command"],
"template_fields_renderers": {},
"task_type": "CustomOperator",
"_operator_name": "@custom",
"_task_module": "tests_common.test_utils.mock_operators",
"pool": "default_pool",
"is_setup": False,
"is_teardown": False,
"on_failure_fail_dagrun": False,
"_needs_expansion": False,
"weight_rule": "downstream",
"start_trigger_args": None,
"start_from_trigger": False,
},
},
],
"timezone": "UTC",
"access_control": {
"__type": "dict",
"__var": {
"test_role": {
"__type": "dict",
"__var": {
"DAGs": {
"__type": "set",
"__var": [
permissions.ACTION_CAN_READ,
permissions.ACTION_CAN_EDIT,
],
}
},
}
},
},
"edge_info": {},
"dag_dependencies": [
{
"dependency_id": '{"name": "asset-2", "uri": "asset-2"}',
"dependency_type": "asset",
"label": "asset-2",
"source": "simple_dag",
"target": "asset",
},
],
"params": [],
"tags": [],
},
}
# Test that v2 DAGs can be deserialized without explicit conversion calls
dag = SerializedDAG.from_dict(v2)
expected_sdag = copy.deepcopy(serialized_simple_dag_ground_truth)
expected = SerializedDAG.from_dict(expected_sdag)
fields_to_verify = set(vars(expected).keys()) - {
"task_group", # Tested separately
"last_loaded", # Dynamically set to utcnow
}
for f in fields_to_verify:
dag_value = getattr(dag, f)
expected_value = getattr(expected, f)
assert dag_value == expected_value, (
f"V2 DAG field '{f}' differs from V3: V2={dag_value!r} != V3={expected_value!r}"
)
for f in set(vars(expected.task_group).keys()) - {"dag"}:
dag_tg_value = getattr(dag.task_group, f)
expected_tg_value = getattr(expected.task_group, f)
assert dag_tg_value == expected_tg_value, (
f"V2 task_group field '{f}' differs: V2={dag_tg_value!r} != V3={expected_tg_value!r}"
)
def test_dag_schema_defaults_optimization():
"""Test that DAG fields matching schema defaults are excluded from serialization."""
# Create DAG with all schema default values
dag_with_defaults = DAG(
dag_id="test_defaults_dag",
start_date=datetime(2023, 1, 1),
# These should match schema defaults and be excluded
catchup=False,
fail_fast=False,
max_active_runs=16,
max_active_tasks=16,
max_consecutive_failed_dag_runs=0,
render_template_as_native_obj=False,
disable_bundle_versioning=False,
# These should be excluded as None
description=None,
doc_md=None,
)
# Serialize and check exclusions
serialized = SerializedDAG.to_dict(dag_with_defaults)
dag_data = serialized["dag"]
# Schema default fields should be excluded
for field in SerializedDAG.get_schema_defaults("dag").keys():
assert field not in dag_data, f"Schema default field '{field}' should be excluded"
# None fields should also be excluded
none_fields = ["description", "doc_md"]
for field in none_fields:
assert field not in dag_data, f"None field '{field}' should be excluded"
# Test deserialization restores defaults correctly
deserialized_dag = SerializedDAG.from_dict(serialized)
# Verify schema defaults are restored
assert deserialized_dag.catchup is False
assert deserialized_dag.fail_fast is False
assert deserialized_dag.max_active_runs == 16
assert deserialized_dag.max_active_tasks == 16
assert deserialized_dag.max_consecutive_failed_dag_runs == 0
assert deserialized_dag.render_template_as_native_obj is False
assert deserialized_dag.disable_bundle_versioning is False
# Test with non-default values (should be included)
dag_non_defaults = DAG(
dag_id="test_non_defaults_dag",
start_date=datetime(2023, 1, 1),
catchup=True, # Non-default
max_active_runs=32, # Non-default
description="Test description", # Non-None
)
serialized_non_defaults = SerializedDAG.to_dict(dag_non_defaults)
dag_non_defaults_data = serialized_non_defaults["dag"]
# Non-default values should be included
assert "catchup" in dag_non_defaults_data
assert dag_non_defaults_data["catchup"] is True
assert "max_active_runs" in dag_non_defaults_data
assert dag_non_defaults_data["max_active_runs"] == 32
assert "description" in dag_non_defaults_data
assert dag_non_defaults_data["description"] == "Test description"
def test_email_optimization_removes_email_attrs_when_email_empty():
"""Test that email_on_failure and email_on_retry are removed when email is empty."""
with DAG(dag_id="test_email_optimization") as dag:
BashOperator(
task_id="test_task",
bash_command="echo test",
email=None, # Empty email
email_on_failure=True, # This should be removed during serialization
email_on_retry=True, # This should be removed during serialization
)
serialized_dag = SerializedDAG.to_dict(dag)
task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]
assert task_serialized is not None
assert "email_on_failure" not in task_serialized
assert "email_on_retry" not in task_serialized
# But they should be present when email is not empty
with DAG(dag_id="test_email_with_attrs") as dag_with_email:
BashOperator(
task_id="test_task_with_email",
bash_command="echo test",
email="test@example.com", # Non-empty email
email_on_failure=True,
email_on_retry=True,
)
serialized_dag_with_email = SerializedDAG.to_dict(dag_with_email)
task_with_email_serialized = serialized_dag_with_email["dag"]["tasks"][0]["__var"]
assert task_with_email_serialized is not None
# email_on_failure and email_on_retry SHOULD be in the serialized task
# since email is not empty
assert "email" in task_with_email_serialized
assert task_with_email_serialized["email"] == "test@example.com"
def dummy_callback():
pass
@pytest.mark.parametrize(
("callback_config", "expected_flags", "is_mapped"),
[
# Regular operator tests
(
{
"on_failure_callback": dummy_callback,
"on_retry_callback": [dummy_callback, dummy_callback],
"on_success_callback": dummy_callback,
},
{"has_on_failure_callback": True, "has_on_retry_callback": True, "has_on_success_callback": True},
False,
),
(
{}, # No callbacks
{
"has_on_failure_callback": False,
"has_on_retry_callback": False,
"has_on_success_callback": False,
},
False,
),
(
{"on_failure_callback": [], "on_success_callback": None}, # Empty callbacks
{"has_on_failure_callback": False, "has_on_success_callback": False},
False,
),
# Mapped operator tests
(
{"on_failure_callback": dummy_callback, "on_success_callback": [dummy_callback, dummy_callback]},
{"has_on_failure_callback": True, "has_on_success_callback": True},
True,
),
(
{}, # Mapped operator without callbacks
{"has_on_failure_callback": False, "has_on_success_callback": False},
True,
),
],
)
def test_task_callback_boolean_optimization(callback_config, expected_flags, is_mapped):
"""Test that task callbacks are optimized using has_on_*_callback boolean flags."""
dag = DAG(dag_id="test_callback_dag")
if is_mapped:
# Create mapped operator
task = BashOperator.partial(task_id="test_task", dag=dag, **callback_config).expand(
bash_command=["echo 1", "echo 2"]
)
serialized = BaseSerialization.serialize(task)
deserialized = BaseSerialization.deserialize(serialized)
# For mapped operators, check partial_kwargs
serialized_data = serialized.get("__var", {}).get("partial_kwargs", {})
# Test serialization
for flag, expected in expected_flags.items():
if expected:
assert flag in serialized_data
assert serialized_data[flag] is True
else:
assert serialized_data.get(flag, False) is False
# Test deserialized properties
for flag, expected in expected_flags.items():
assert getattr(deserialized, flag) is expected
else:
# Create regular operator
task = BashOperator(task_id="test_task", bash_command="echo test", dag=dag, **callback_config)
serialized = BaseSerialization.serialize(task)
deserialized = BaseSerialization.deserialize(serialized)
# For regular operators, check top-level
serialized_data = serialized.get("__var", {})
# Test serialization (only True values are stored)
for flag, expected in expected_flags.items():
if expected:
assert serialized_data.get(flag, False) is True
else:
assert serialized_data.get(flag, False) is False
# Test deserialized properties
for flag, expected in expected_flags.items():
assert getattr(deserialized, flag) is expected
@pytest.mark.parametrize(
"kwargs",
[
{"inlets": [Asset(uri="file://some.txt")]},
{"outlets": [Asset(uri="file://some.txt")]},
{"on_success_callback": lambda *args, **kwargs: None},
{"on_execute_callback": lambda *args, **kwargs: None},
],
)
def test_is_schedulable_task_empty_operator_evaluates_true(kwargs):
from airflow.providers.standard.operators.empty import EmptyOperator
dag = DAG(dag_id="test_dag")
task = EmptyOperator(task_id="empty_task", dag=dag, **kwargs)
serialized_task = BaseSerialization.deserialize(BaseSerialization.serialize(task))
assert TI.is_task_schedulable(serialized_task)
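# Taken together with the next test: an EmptyOperator only needs to be scheduled when it
# has something observable to do (inlets/outlets, or an on_success / on_execute callback);
# with no kwargs, or with only failure/skip/retry callbacks, it is treated as not
# schedulable.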
@pytest.mark.parametrize(
"kwargs",
[
{},
{"on_failure_callback": lambda *args, **kwargs: None},
{"on_skipped_callback": lambda *args, **kwargs: None},
{"on_retry_callback": lambda *args, **kwargs: None},
],
)
def test_is_schedulable_task_empty_operator_evaluates_false(kwargs):
from airflow.providers.standard.operators.empty import EmptyOperator
dag = DAG(dag_id="test_dag")
task = EmptyOperator(task_id="empty_task", dag=dag, **kwargs)
serialized_task = BaseSerialization.deserialize(BaseSerialization.serialize(task))
assert not TI.is_task_schedulable(serialized_task)
def test_is_schedulable_task_non_empty_operator():
dag = DAG(dag_id="test_dag")
regular_task = BashOperator(task_id="regular", bash_command="echo test", dag=dag)
mapped_task = BashOperator.partial(task_id="mapped", dag=dag).expand(bash_command=["echo 1"])
serialized_regular = BaseSerialization.deserialize(BaseSerialization.serialize(regular_task))
serialized_mapped = BaseSerialization.deserialize(BaseSerialization.serialize(mapped_task))
assert TI.is_task_schedulable(serialized_regular)
assert TI.is_task_schedulable(serialized_mapped)
def test_task_callback_properties_exist():
"""Test that all callback boolean properties exist on both regular and mapped operators."""
dag = DAG(dag_id="test_dag")
regular_task = BashOperator(task_id="regular", bash_command="echo test", dag=dag)
mapped_task = BashOperator.partial(task_id="mapped", dag=dag).expand(bash_command=["echo 1"])
callback_properties = [
"has_on_execute_callback",
"has_on_failure_callback",
"has_on_success_callback",
"has_on_retry_callback",
"has_on_skipped_callback",
]
for prop in callback_properties:
assert hasattr(regular_task, prop), f"Regular operator missing {prop}"
assert hasattr(mapped_task, prop), f"Mapped operator missing {prop}"
serialized_regular = BaseSerialization.deserialize(BaseSerialization.serialize(regular_task))
serialized_mapped = BaseSerialization.deserialize(BaseSerialization.serialize(mapped_task))
assert hasattr(serialized_regular, prop), f"Deserialized regular operator missing {prop}"
assert hasattr(serialized_mapped, prop), f"Deserialized mapped operator missing {prop}"
@pytest.mark.parametrize(
("old_callback_name", "new_callback_name"),
[
("on_execute_callback", "has_on_execute_callback"),
("on_failure_callback", "has_on_failure_callback"),
("on_success_callback", "has_on_success_callback"),
("on_retry_callback", "has_on_retry_callback"),
("on_skipped_callback", "has_on_skipped_callback"),
],
)
def test_task_callback_backward_compatibility(old_callback_name, new_callback_name):
"""Test that old serialized DAGs with on_*_callback keys are correctly converted to has_on_*_callback."""
old_serialized_task = {
"is_setup": False,
old_callback_name: [
" def dumm_callback(*args, **kwargs):\n # hello\n pass\n"
],
"is_teardown": False,
"task_type": "BaseOperator",
"pool": "default_pool",
"task_id": "simple_task",
"template_fields": [],
"on_failure_fail_dagrun": False,
"downstream_task_ids": [],
"template_ext": [],
"ui_fgcolor": "#000",
"weight_rule": "downstream",
"ui_color": "#fff",
"template_fields_renderers": {},
"_needs_expansion": False,
"start_from_trigger": False,
"_task_module": "airflow.sdk.bases.operator",
"start_trigger_args": None,
}
# Test deserialization converts old format to new format
deserialized_task = SerializedBaseOperator.deserialize_operator(old_serialized_task)
# Verify the new format is present and correct
assert hasattr(deserialized_task, new_callback_name)
assert getattr(deserialized_task, new_callback_name) is True
assert not hasattr(deserialized_task, old_callback_name)
# Test with empty/None callback (should convert to False)
old_serialized_task[old_callback_name] = None
deserialized_task_empty = SerializedBaseOperator.deserialize_operator(old_serialized_task)
assert getattr(deserialized_task_empty, new_callback_name) is False
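# Rough before/after sketch of the conversion exercised above (illustrative only):
#   old payload:    {"on_failure_callback": ["<callback source>"], ...}
#   new attribute:  deserialized.has_on_failure_callback is True
# A missing or None callback in the old payload becomes a False flag after deserialization,
# and the old on_*_callback key does not survive as an attribute.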
def test_weight_rule_absolute_serialization_deserialization():
"""Test that weight_rule can be serialized and deserialized correctly."""
from airflow.sdk import task
with DAG("test_weight_rule_dag") as dag:
@task(weight_rule=WeightRule.ABSOLUTE)
def test_task():
return "test"
test_task()
serialized_dag = SerializedDAG.to_dict(dag)
assert serialized_dag["dag"]["tasks"][0]["__var"]["weight_rule"] == "absolute"
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_task = deserialized_dag.task_dict["test_task"]
assert isinstance(deserialized_task.weight_rule, _AbsolutePriorityWeightStrategy)
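# For reference: the serialized weight_rule is a plain string; "absolute" round-trips to
# _AbsolutePriorityWeightStrategy, and the default "downstream" (as seen in the legacy payload
# above) is expected to map to its _DownstreamPriorityWeightStrategy counterpart.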
class TestClientDefaultsGeneration:
"""Test client defaults generation functionality."""
def test_generate_client_defaults_basic(self):
"""Test basic client defaults generation."""
client_defaults = SerializedBaseOperator.generate_client_defaults()
assert isinstance(client_defaults, dict)
# Should only include serializable fields
serialized_fields = SerializedBaseOperator.get_serialized_fields()
for field in client_defaults:
assert field in serialized_fields, f"Field {field} not in serialized fields"
def test_generate_client_defaults_excludes_schema_defaults(self):
"""Test that client defaults excludes values that match schema defaults."""
client_defaults = SerializedBaseOperator.generate_client_defaults()
schema_defaults = SerializedBaseOperator.get_schema_defaults("operator")
# Check that values matching schema defaults are excluded
for field, value in client_defaults.items():
if field in schema_defaults:
assert value != schema_defaults[field], (
f"Field {field} has value {value!r} which matches schema default {schema_defaults[field]!r}"
)
def test_generate_client_defaults_excludes_none_and_empty(self):
"""Test that client defaults excludes None and empty collection values."""
client_defaults = SerializedBaseOperator.generate_client_defaults()
for field, value in client_defaults.items():
assert value is not None, f"Field {field} has None value"
assert value not in [[], (), set(), {}], f"Field {field} has empty collection value: {value!r}"
def test_generate_client_defaults_caching(self):
"""Test that client defaults generation is cached."""
# Clear cache first
SerializedBaseOperator.generate_client_defaults.cache_clear()
# First call
client_defaults_1 = SerializedBaseOperator.generate_client_defaults()
# Second call should return same object (cached)
client_defaults_2 = SerializedBaseOperator.generate_client_defaults()
assert client_defaults_1 is client_defaults_2, "Client defaults should be cached"
# Check cache info
cache_info = SerializedBaseOperator.generate_client_defaults.cache_info()
assert cache_info.hits >= 1, "Cache should have at least one hit"
def test_generate_client_defaults_only_operator_defaults_fields(self):
"""Test that only fields from OPERATOR_DEFAULTS are considered."""
client_defaults = SerializedBaseOperator.generate_client_defaults()
# All fields in client_defaults should originate from OPERATOR_DEFAULTS
for field in client_defaults:
assert field in OPERATOR_DEFAULTS, f"Field {field} not in OPERATOR_DEFAULTS"
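# A minimal sketch of the contract the tests above rely on (this is not the real
# implementation): generate_client_defaults() starts from OPERATOR_DEFAULTS, keeps only
# serializable fields, and drops None values, empty collections, and anything equal to the
# JSON-schema default. Roughly:
#
#   schema_defaults = SerializedBaseOperator.get_schema_defaults("operator")
#   client_defaults = {
#       k: v
#       for k, v in OPERATOR_DEFAULTS.items()
#       if k in SerializedBaseOperator.get_serialized_fields()
#       and v is not None
#       and v not in ([], (), set(), {})
#       and schema_defaults.get(k) != v
#   }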
class TestSchemaDefaults:
"""Test schema defaults functionality."""
def test_get_schema_defaults_operator(self):
"""Test getting schema defaults for operator type."""
schema_defaults = SerializedBaseOperator.get_schema_defaults("operator")
assert isinstance(schema_defaults, dict)
# Should contain expected operator defaults
expected_fields = [
"owner",
"trigger_rule",
"depends_on_past",
"retries",
"queue",
"pool",
"pool_slots",
"priority_weight",
"weight_rule",
"do_xcom_push",
]
for field in expected_fields:
assert field in schema_defaults, f"Expected field {field} not in schema defaults"
def test_get_schema_defaults_nonexistent_type(self):
"""Test getting schema defaults for nonexistent type."""
schema_defaults = SerializedBaseOperator.get_schema_defaults("nonexistent")
assert schema_defaults == {}
def test_get_operator_optional_fields_from_schema(self):
"""Test getting optional fields from schema."""
optional_fields = SerializedBaseOperator.get_operator_optional_fields_from_schema()
assert isinstance(optional_fields, set)
# Should not contain required fields
required_fields = {
"task_type",
"_task_module",
"task_id",
"ui_color",
"ui_fgcolor",
"template_fields",
}
overlap = optional_fields & required_fields
assert not overlap, f"Optional fields should not overlap with required fields: {overlap}"
class TestDeserializationDefaultsResolution:
"""Test defaults resolution during deserialization."""
def test_apply_defaults_to_encoded_op(self):
encoded_op = {"task_id": "test_task", "task_type": "BashOperator", "retries": 10}
client_defaults = {"tasks": {"retry_delay": 300.0, "retries": 2}} # Fix: wrap in "tasks"
result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, client_defaults)
# Should merge in order: client_defaults, encoded_op
assert result["retry_delay"] == 300.0 # From client_defaults
assert result["task_id"] == "test_task" # From encoded_op (highest priority)
assert result["retries"] == 10
def test_apply_defaults_to_encoded_op_none_inputs(self):
"""Test defaults application with None inputs."""
encoded_op = {"task_id": "test_task"}
# With None client_defaults
result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, None)
assert result == encoded_op
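# Merge precedence assumed by the two tests above (a sketch, not the exact implementation):
#   result = {**client_defaults.get("tasks", {}), **encoded_op}
# i.e. client_defaults fill the gaps, while anything explicitly present in the encoded
# operator wins; a None client_defaults leaves the encoded operator untouched.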
@operator_defaults({"retries": 2})
def test_multiple_tasks_share_client_defaults(self):
"""Test that multiple tasks can share the same client_defaults when there are actually non-default values."""
with DAG(dag_id="test_dag") as dag:
BashOperator(task_id="task1", bash_command="echo 1")
BashOperator(task_id="task2", bash_command="echo 2")
serialized = SerializedDAG.to_dict(dag)
# Should have one client_defaults section for all tasks
assert "client_defaults" in serialized
assert "tasks" in serialized["client_defaults"]
# All tasks should benefit from the same client_defaults
client_defaults = serialized["client_defaults"]["tasks"]
# Deserialize and check both tasks get the defaults
deserialized_dag = SerializedDAG.from_dict(serialized)
deserialized_task1 = deserialized_dag.get_task("task1")
deserialized_task2 = deserialized_dag.get_task("task2")
# Both tasks should have retries=2 from client_defaults
assert deserialized_task1.retries == 2
assert deserialized_task2.retries == 2
# Both tasks should have the same default values from client_defaults
for field in client_defaults:
if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
value1 = getattr(deserialized_task1, field)
value2 = getattr(deserialized_task2, field)
assert value1 == value2, f"Tasks have different values for {field}: {value1} vs {value2}"
class TestMappedOperatorSerializationAndClientDefaults:
"""Test MappedOperator serialization with client defaults and callback properties."""
@operator_defaults({"retry_delay": 200.0})
def test_mapped_operator_client_defaults_application(self):
"""Test that client_defaults are correctly applied to MappedOperator during deserialization."""
with DAG(dag_id="test_mapped_dag") as dag:
# Create a mapped operator
BashOperator.partial(
task_id="mapped_task",
retries=5, # Override default
).expand(bash_command=["echo 1", "echo 2", "echo 3"])
# Serialize the DAG
serialized_dag = SerializedDAG.to_dict(dag)
# Should have client_defaults section
assert "client_defaults" in serialized_dag
assert "tasks" in serialized_dag["client_defaults"]
# Deserialize and check that client_defaults are applied
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_task = deserialized_dag.get_task("mapped_task")
# Verify it's still a MappedOperator
from airflow.models.mappedoperator import MappedOperator as SchedulerMappedOperator
assert isinstance(deserialized_task, SchedulerMappedOperator)
# Check that client_defaults values are applied (e.g., retry_delay from client_defaults)
client_defaults = serialized_dag["client_defaults"]["tasks"]
if "retry_delay" in client_defaults:
# If retry_delay wasn't explicitly set, it should come from client_defaults
# Since we can't easily convert timedelta back, check the serialized format
assert hasattr(deserialized_task, "retry_delay")
# Explicit values should override client_defaults
assert deserialized_task.retries == 5 # Explicitly set value
@pytest.mark.parametrize(
("task_config", "dag_id", "task_id", "non_default_fields"),
[
# Test case 1: Size optimization with non-default values
pytest.param(
{"retries": 3}, # Only set non-default values
"test_mapped_size",
"mapped_size_test",
{"retries"},
id="non_default_fields",
),
# Test case 2: No duplication with default values
pytest.param(
{"retries": 0}, # This should match client_defaults and be optimized out
"test_no_duplication",
"mapped_task",
set(), # No fields should be non-default (all optimized out)
id="duplicate_fields",
),
# Test case 3: Mixed default/non-default values
pytest.param(
{"retries": 2, "max_active_tis_per_dag": 16}, # Mix of default and non-default
"test_mixed_optimization",
"mixed_task",
{"retries", "max_active_tis_per_dag"}, # Both should be preserved as they're non-default
id="test_mixed_optimization",
),
],
)
@operator_defaults({"retry_delay": 200.0})
def test_mapped_operator_client_defaults_optimization(
self, task_config, dag_id, task_id, non_default_fields
):
"""Test that MappedOperator serialization optimizes using client defaults."""
with DAG(dag_id=dag_id) as dag:
# Create mapped operator with specified configuration
BashOperator.partial(
task_id=task_id,
**task_config,
).expand(bash_command=["echo 1", "echo 2", "echo 3"])
serialized_dag = SerializedDAG.to_dict(dag)
mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"]
assert mapped_task_serialized is not None
assert mapped_task_serialized.get("_is_mapped") is True
# Check optimization behavior
client_defaults = serialized_dag["client_defaults"]["tasks"]
partial_kwargs = mapped_task_serialized["partial_kwargs"]
# Check that all fields are optimized correctly
for field, default_value in client_defaults.items():
if field in non_default_fields:
# Non-default fields should be present in partial_kwargs
assert field in partial_kwargs, (
f"Field '{field}' should be in partial_kwargs as it's non-default"
)
# And have different values than defaults
assert partial_kwargs[field] != default_value, (
f"Field '{field}' should have non-default value"
)
else:
# Default-valued fields should be optimized out; if the key is still present, its value must differ from the client default
if field in partial_kwargs:
assert partial_kwargs[field] != default_value, (
f"Field '{field}' with default value should be optimized out"
)
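# Abridged shape of the serialized payload inspected above (illustrative, keys trimmed):
#   {
#       "client_defaults": {"tasks": {"retry_delay": 200.0, ...}},
#       "dag": {"tasks": [{"__var": {"_is_mapped": True, "partial_kwargs": {"retries": 3, ...}, ...}}]},
#   }
# partial_kwargs entries equal to a client default are optimized out of the payload.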
def test_mapped_operator_expand_input_preservation(self):
"""Test that expand_input is correctly preserved during serialization."""
with DAG(dag_id="test_expand_input"):
mapped_task = BashOperator.partial(task_id="test_expand").expand(
bash_command=["echo 1", "echo 2", "echo 3"], env={"VAR1": "value1", "VAR2": "value2"}
)
# Serialize and deserialize
serialized = BaseSerialization.serialize(mapped_task)
deserialized = BaseSerialization.deserialize(serialized)
# Check expand_input structure
assert hasattr(deserialized, "expand_input")
expand_input = deserialized.expand_input
# Verify the expand_input contains the expected data
assert hasattr(expand_input, "value")
expand_value = expand_input.value
assert "bash_command" in expand_value
assert "env" in expand_value
assert expand_value["bash_command"] == ["echo 1", "echo 2", "echo 3"]
assert expand_value["env"] == {"VAR1": "value1", "VAR2": "value2"}
@pytest.mark.parametrize(
("partial_kwargs_data", "expected_results"),
[
# Test case 1: Encoded format with client defaults
pytest.param(
{
"retry_delay": {"__type": "timedelta", "__var": 600.0},
"execution_timeout": {"__type": "timedelta", "__var": 1800.0},
"owner": "test_user",
},
{
"retry_delay": timedelta(seconds=600),
"execution_timeout": timedelta(seconds=1800),
"owner": "test_user",
},
id="encoded_with_client_defaults",
),
# Test case 2: Non-encoded format (optimized)
pytest.param(
{
"retry_delay": 600.0,
"execution_timeout": 1800.0,
},
{
"retry_delay": timedelta(seconds=600),
"execution_timeout": timedelta(seconds=1800),
},
id="non_encoded_optimized",
),
# Test case 3: Mixed format (some encoded, some not)
pytest.param(
{
"retry_delay": {"__type": "timedelta", "__var": 600.0}, # Encoded
"execution_timeout": 1800.0, # Non-encoded
},
{
"retry_delay": timedelta(seconds=600),
"execution_timeout": timedelta(seconds=1800),
},
id="mixed_encoded_non_encoded",
),
],
)
def test_partial_kwargs_deserialization_formats(self, partial_kwargs_data, expected_results):
"""Test deserialization of partial_kwargs in various formats (encoded, non-encoded, mixed)."""
result = SerializedBaseOperator._deserialize_partial_kwargs(partial_kwargs_data)
# Verify all expected results
for key, expected_value in expected_results.items():
assert key in result, f"Missing key '{key}' in result"
assert result[key] == expected_value, f"key '{key}': expected {expected_value}, got {result[key]}"
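# _deserialize_partial_kwargs therefore has to accept both spellings of the same value, e.g.
#   {"__type": "timedelta", "__var": 600.0}   (classic encoded form)
#   600.0                                     (optimized plain form)
# and, per the cases above, both are expected to come back as timedelta(seconds=600).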
def test_partial_kwargs_end_to_end_deserialization(self):
"""Test end-to-end partial_kwargs deserialization with real MappedOperator."""
with DAG(dag_id="test_e2e_partial_kwargs") as dag:
BashOperator.partial(
task_id="mapped_task",
retry_delay=timedelta(seconds=600), # Non-default value
owner="custom_owner", # Non-default value
# retries is not specified here; it may be filled in from client_defaults
).expand(bash_command=["echo 1", "echo 2"])
# Serialize and deserialize the DAG
serialized_dag = SerializedDAG.to_dict(dag)
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_task = deserialized_dag.get_task("mapped_task")
# Verify the task has correct values after round-trip
assert deserialized_task.retry_delay == timedelta(seconds=600)
assert deserialized_task.owner == "custom_owner"
# Verify partial_kwargs were deserialized correctly
assert "retry_delay" in deserialized_task.partial_kwargs
assert "owner" in deserialized_task.partial_kwargs
assert deserialized_task.partial_kwargs["retry_delay"] == timedelta(seconds=600)
assert deserialized_task.partial_kwargs["owner"] == "custom_owner"
@pytest.mark.parametrize(
("callbacks", "expected_has_flags", "absent_keys"),
[
pytest.param(
{
"on_failure_callback": lambda ctx: None,
"on_success_callback": lambda ctx: None,
"on_retry_callback": lambda ctx: None,
},
["has_on_failure_callback", "has_on_success_callback", "has_on_retry_callback"],
["on_failure_callback", "on_success_callback", "on_retry_callback"],
id="multiple_callbacks",
),
pytest.param(
{"on_failure_callback": lambda ctx: None},
["has_on_failure_callback"],
["on_failure_callback", "has_on_success_callback", "on_success_callback"],
id="single_callback",
),
pytest.param(
{"on_failure_callback": lambda ctx: None, "on_execute_callback": None},
["has_on_failure_callback"],
["on_failure_callback", "has_on_execute_callback", "on_execute_callback"],
id="callback_with_none",
),
pytest.param(
{},
[],
[
"has_on_execute_callback",
"has_on_failure_callback",
"has_on_success_callback",
"has_on_retry_callback",
"has_on_skipped_callback",
],
id="no_callbacks",
),
],
)
def test_dag_default_args_callbacks_serialization(callbacks, expected_has_flags, absent_keys):
"""Test callbacks in DAG default_args are serialized as boolean flags."""
default_args = {"owner": "test_owner", "retries": 2, **callbacks}
with DAG(dag_id="test_default_args_callbacks", default_args=default_args) as dag:
BashOperator(task_id="task1", bash_command="echo 1", dag=dag)
serialized_dag_dict = SerializedDAG.serialize_dag(dag)
default_args_dict = serialized_dag_dict["default_args"][Encoding.VAR]
for flag in expected_has_flags:
assert default_args_dict.get(flag) is True
for key in absent_keys:
assert key not in default_args_dict
assert default_args_dict["owner"] == "test_owner"
assert default_args_dict["retries"] == 2
deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag_dict)
assert deserialized_dag.dag_id == "test_default_args_callbacks"
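# Abridged view of serialized_dag_dict["default_args"][Encoding.VAR] asserted on above:
#   {"owner": "test_owner", "retries": 2, "has_on_failure_callback": True, ...}
# The callables themselves never appear in the serialized DAG; only the has_on_*_callback
# boolean flags survive serialization.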