blob: 0f319a1afb363a3a6d76e1ffa754c5d0a82c6408 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Serialized DAG and BaseOperator."""
from __future__ import annotations
import collections.abc
import datetime
import enum
import inspect
import logging
import warnings
import weakref
from dataclasses import dataclass
from inspect import Parameter, signature
from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, NamedTuple, Union
import cattr
import lazy_object_proxy
import pendulum
from dateutil import relativedelta
from pendulum.tz.timezone import FixedTimezone, Timezone
from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError
from airflow.jobs.job import Job
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.connection import Connection
from airflow.models.dag import DAG, create_timetable
from airflow.models.dagrun import DagRun
from airflow.models.expandinput import EXPAND_INPUT_EMPTY, ExpandInput, create_expand_input, get_map_type_key
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.param import Param, ParamsDict
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.models.taskmixin import DAGNode
from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
from airflow.providers_manager import ProvidersManager
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import serialize_template_field
from airflow.serialization.json_schema import Validator, load_dag_schema
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.dataset import DatasetPydantic
from airflow.serialization.pydantic.job import JobPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.settings import _ENABLE_AIP_44, DAGS_FOLDER, json
from airflow.timetables.base import Timetable
from airflow.utils.code_utils import get_python_source
from airflow.utils.docs import get_docs_url
from airflow.utils.module_loading import import_string, qualname
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
if TYPE_CHECKING:
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
HAS_KUBERNETES: bool
try:
from kubernetes.client import models as k8s
from airflow.kubernetes.pod_generator import PodGenerator
except ImportError:
pass
log = logging.getLogger(__name__)
_OPERATOR_EXTRA_LINKS: set[str] = {
"airflow.operators.trigger_dagrun.TriggerDagRunLink",
"airflow.sensors.external_task.ExternalDagLink",
# Deprecated names, so that existing serialized dags load straight away.
"airflow.sensors.external_task.ExternalTaskSensorLink",
"airflow.operators.dagrun_operator.TriggerDagRunLink",
"airflow.sensors.external_task_sensor.ExternalTaskSensorLink",
}
@cache
def get_operator_extra_links() -> set[str]:
"""Get the operator extra links.
This includes both the built-in ones, and those come from the providers.
"""
_OPERATOR_EXTRA_LINKS.update(ProvidersManager().extra_links_class_names)
return _OPERATOR_EXTRA_LINKS
@cache
def _get_default_mapped_partial() -> dict[str, Any]:
"""Get default partial kwargs in a mapped operator.
This is used to simplify a serialized mapped operator by excluding default
values supplied in the implementation from the serialized dict. Since those
are defaults, they are automatically supplied on de-serialization, so we
don't need to store them.
"""
# Use the private _expand() method to avoid the empty kwargs check.
default = BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs
return BaseSerialization.serialize(default)[Encoding.VAR]
def encode_relativedelta(var: relativedelta.relativedelta) -> dict[str, Any]:
"""Encode a relativedelta object."""
encoded = {k: v for k, v in var.__dict__.items() if not k.startswith("_") and v}
if var.weekday and var.weekday.n:
# Every n'th Friday for example
encoded["weekday"] = [var.weekday.weekday, var.weekday.n]
elif var.weekday:
encoded["weekday"] = [var.weekday.weekday]
return encoded
def decode_relativedelta(var: dict[str, Any]) -> relativedelta.relativedelta:
"""Dencode a relativedelta object."""
if "weekday" in var:
var["weekday"] = relativedelta.weekday(*var["weekday"]) # type: ignore
return relativedelta.relativedelta(**var)
def encode_timezone(var: Timezone) -> str | int:
"""Encode a Pendulum Timezone for serialization.
Airflow only supports timezone objects that implements Pendulum's Timezone
interface. We try to keep as much information as possible to make conversion
round-tripping possible (see ``decode_timezone``). We need to special-case
UTC; Pendulum implements it as a FixedTimezone (i.e. it gets encoded as
0 without the special case), but passing 0 into ``pendulum.timezone`` does
not give us UTC (but ``+00:00``).
"""
if isinstance(var, FixedTimezone):
if var.offset == 0:
return "UTC"
return var.offset
if isinstance(var, Timezone):
return var.name
raise ValueError(
f"DAG timezone should be a pendulum.tz.Timezone, not {var!r}. "
f"See {get_docs_url('timezone.html#time-zone-aware-dags')}"
)
def decode_timezone(var: str | int) -> Timezone:
"""Decode a previously serialized Pendulum Timezone."""
return pendulum.tz.timezone(var)
def _get_registered_timetable(importable_string: str) -> type[Timetable] | None:
from airflow import plugins_manager
if importable_string.startswith("airflow.timetables."):
return import_string(importable_string)
plugins_manager.initialize_timetables_plugins()
if plugins_manager.timetable_classes:
return plugins_manager.timetable_classes.get(importable_string)
else:
return None
class _TimetableNotRegistered(ValueError):
def __init__(self, type_string: str) -> None:
self.type_string = type_string
def __str__(self) -> str:
return (
f"Timetable class {self.type_string!r} is not registered or "
"you have a top level database access that disrupted the session. "
"Please check the airflow best practices documentation."
)
def _encode_timetable(var: Timetable) -> dict[str, Any]:
"""Encode a timetable instance.
This delegates most of the serialization work to the type, so the behavior
can be completely controlled by a custom subclass.
"""
timetable_class = type(var)
importable_string = qualname(timetable_class)
if _get_registered_timetable(importable_string) is None:
raise _TimetableNotRegistered(importable_string)
return {Encoding.TYPE: importable_string, Encoding.VAR: var.serialize()}
def _decode_timetable(var: dict[str, Any]) -> Timetable:
"""Decode a previously serialized timetable.
Most of the deserialization logic is delegated to the actual type, which
we import from string.
"""
importable_string = var[Encoding.TYPE]
timetable_class = _get_registered_timetable(importable_string)
if timetable_class is None:
raise _TimetableNotRegistered(importable_string)
return timetable_class.deserialize(var[Encoding.VAR])
class _XComRef(NamedTuple):
"""Used to store info needed to create XComArg.
We can't turn it in to a XComArg until we've loaded _all_ the tasks, so when
deserializing an operator, we need to create something in its place, and
post-process it in ``deserialize_dag``.
"""
data: dict
def deref(self, dag: DAG) -> XComArg:
return deserialize_xcom_arg(self.data, dag)
# These two should be kept in sync. Note that these are intentionally not using
# the type declarations in expandinput.py so we always remember to update
# serialization logic when adding new ExpandInput variants. If you add things to
# the unions, be sure to update _ExpandInputRef to match.
_ExpandInputOriginalValue = Union[
# For .expand(**kwargs).
Mapping[str, Any],
# For expand_kwargs(arg).
XComArg,
Collection[Union[XComArg, Mapping[str, Any]]],
]
_ExpandInputSerializedValue = Union[
# For .expand(**kwargs).
Mapping[str, Any],
# For expand_kwargs(arg).
_XComRef,
Collection[Union[_XComRef, Mapping[str, Any]]],
]
class _ExpandInputRef(NamedTuple):
"""Used to store info needed to create a mapped operator's expand input.
This references a ``ExpandInput`` type, but replaces ``XComArg`` objects
with ``_XComRef`` (see documentation on the latter type for reasoning).
"""
key: str
value: _ExpandInputSerializedValue
@classmethod
def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None:
"""Validate we've covered all ``ExpandInput.value`` types.
This function does not actually do anything, but is called during
serialization so Mypy will *statically* check we have handled all
possible ExpandInput cases.
"""
def deref(self, dag: DAG) -> ExpandInput:
"""De-reference into a concrete ExpandInput object.
If you add more cases here, be sure to update _ExpandInputOriginalValue
and _ExpandInputSerializedValue to match the logic.
"""
if isinstance(self.value, _XComRef):
value: Any = self.value.deref(dag)
elif isinstance(self.value, collections.abc.Mapping):
value = {k: v.deref(dag) if isinstance(v, _XComRef) else v for k, v in self.value.items()}
else:
value = [v.deref(dag) if isinstance(v, _XComRef) else v for v in self.value]
return create_expand_input(self.key, value)
class BaseSerialization:
"""BaseSerialization provides utils for serialization."""
# JSON primitive types.
_primitive_types = (int, bool, float, str)
# Time types.
# datetime.date and datetime.time are converted to strings.
_datetime_types = (datetime.datetime,)
# Object types that are always excluded in serialization.
_excluded_types = (logging.Logger, Connection, type)
_json_schema: Validator | None = None
# Should the extra operator link be loaded via plugins when
# de-serializing the DAG? This flag is set to False in Scheduler so that Extra Operator links
# are not loaded to not run User code in Scheduler.
_load_operator_extra_links = True
_CONSTRUCTOR_PARAMS: dict[str, Parameter] = {}
SERIALIZER_VERSION = 1
@classmethod
def to_json(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> str:
"""Stringifies DAGs and operators contained by var and returns a JSON string of var."""
return json.dumps(cls.to_dict(var), ensure_ascii=True)
@classmethod
def to_dict(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> dict:
"""Stringifies DAGs and operators contained by var and returns a dict of var."""
# Don't call on this class directly - only SerializedDAG or
# SerializedBaseOperator should be used as the "entrypoint"
raise NotImplementedError()
@classmethod
def from_json(cls, serialized_obj: str) -> BaseSerialization | dict | list | set | tuple:
"""Deserializes json_str and reconstructs all DAGs and operators it contains."""
return cls.from_dict(json.loads(serialized_obj))
@classmethod
def from_dict(cls, serialized_obj: dict[Encoding, Any]) -> BaseSerialization | dict | list | set | tuple:
"""Deserialize a dict of type decorators and reconstructs all DAGs and operators it contains."""
return cls.deserialize(serialized_obj)
@classmethod
def validate_schema(cls, serialized_obj: str | dict) -> None:
"""Validate serialized_obj satisfies JSON schema."""
if cls._json_schema is None:
raise AirflowException(f"JSON schema of {cls.__name__:s} is not set.")
if isinstance(serialized_obj, dict):
cls._json_schema.validate(serialized_obj)
elif isinstance(serialized_obj, str):
cls._json_schema.validate(json.loads(serialized_obj))
else:
raise TypeError("Invalid type: Only dict and str are supported.")
@staticmethod
def _encode(x: Any, type_: Any) -> dict[Encoding, Any]:
"""Encode data by a JSON dict."""
return {Encoding.VAR: x, Encoding.TYPE: type_}
@classmethod
def _is_primitive(cls, var: Any) -> bool:
"""Primitive types."""
return var is None or isinstance(var, cls._primitive_types)
@classmethod
def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool:
"""Types excluded from serialization."""
if var is None:
if not cls._is_constructor_param(attrname, instance):
# Any instance attribute, that is not a constructor argument, we exclude None as the default
return True
return cls._value_is_hardcoded_default(attrname, var, instance)
return isinstance(var, cls._excluded_types) or cls._value_is_hardcoded_default(
attrname, var, instance
)
@classmethod
def serialize_to_json(
cls, object_to_serialize: BaseOperator | MappedOperator | DAG, decorated_fields: set
) -> dict[str, Any]:
"""Serializes an object to JSON."""
serialized_object: dict[str, Any] = {}
keys_to_serialize = object_to_serialize.get_serialized_fields()
for key in keys_to_serialize:
# None is ignored in serialized form and is added back in deserialization.
value = getattr(object_to_serialize, key, None)
if cls._is_excluded(value, key, object_to_serialize):
continue
if key == "_operator_name":
# when operator_name matches task_type, we can remove
# it to reduce the JSON payload
task_type = getattr(object_to_serialize, "_task_type", None)
if value != task_type:
serialized_object[key] = cls.serialize(value)
elif key in decorated_fields:
serialized_object[key] = cls.serialize(value)
elif key == "timetable" and value is not None:
serialized_object[key] = _encode_timetable(value)
else:
value = cls.serialize(value)
if isinstance(value, dict) and Encoding.TYPE in value:
value = value[Encoding.VAR]
serialized_object[key] = value
return serialized_object
@classmethod
def serialize(
cls, var: Any, *, strict: bool = False, use_pydantic_models: bool = False
) -> Any: # Unfortunately there is no support for recursive types in mypy
"""Helper function of depth first search for serialization.
The serialization protocol is:
(1) keeping JSON supported types: primitives, dict, list;
(2) encoding other types as ``{TYPE: 'foo', VAR: 'bar'}``, the deserialization
step decode VAR according to TYPE;
(3) Operator has a special field CLASS to record the original class
name for displaying in UI.
:meta private:
"""
if use_pydantic_models and not _ENABLE_AIP_44:
raise RuntimeError(
"Setting use_pydantic_models = True requires AIP-44 (in progress) feature flag to be true. "
"This parameter will be removed eventually when new serialization is used by AIP-44"
)
if cls._is_primitive(var):
# enum.IntEnum is an int instance, it causes json dumps error so we use its value.
if isinstance(var, enum.Enum):
return var.value
return var
elif isinstance(var, dict):
return cls._encode(
{
str(k): cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models)
for k, v in var.items()
},
type_=DAT.DICT,
)
elif isinstance(var, list):
return [cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var]
elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and isinstance(var, k8s.V1Pod):
json_pod = PodGenerator.serialize_pod(var)
return cls._encode(json_pod, type_=DAT.POD)
elif isinstance(var, DAG):
return SerializedDAG.serialize_dag(var)
elif isinstance(var, Resources):
return var.to_dict()
elif isinstance(var, MappedOperator):
return SerializedBaseOperator.serialize_mapped_operator(var)
elif isinstance(var, BaseOperator):
return SerializedBaseOperator.serialize_operator(var)
elif isinstance(var, cls._datetime_types):
return cls._encode(var.timestamp(), type_=DAT.DATETIME)
elif isinstance(var, datetime.timedelta):
return cls._encode(var.total_seconds(), type_=DAT.TIMEDELTA)
elif isinstance(var, Timezone):
return cls._encode(encode_timezone(var), type_=DAT.TIMEZONE)
elif isinstance(var, relativedelta.relativedelta):
return cls._encode(encode_relativedelta(var), type_=DAT.RELATIVEDELTA)
elif callable(var):
return str(get_python_source(var))
elif isinstance(var, set):
# FIXME: casts set to list in customized serialization in future.
try:
return cls._encode(
sorted(
cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var
),
type_=DAT.SET,
)
except TypeError:
return cls._encode(
[cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var],
type_=DAT.SET,
)
elif isinstance(var, tuple):
# FIXME: casts tuple to list in customized serialization in future.
return cls._encode(
[cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var],
type_=DAT.TUPLE,
)
elif isinstance(var, TaskGroup):
return TaskGroupSerialization.serialize_task_group(var)
elif isinstance(var, Param):
return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
elif isinstance(var, XComArg):
return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
elif isinstance(var, Dataset):
return cls._encode(dict(uri=var.uri, extra=var.extra), type_=DAT.DATASET)
elif isinstance(var, SimpleTaskInstance):
return cls._encode(
cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
type_=DAT.SIMPLE_TASK_INSTANCE,
)
elif use_pydantic_models and _ENABLE_AIP_44:
if isinstance(var, Job):
return cls._encode(JobPydantic.from_orm(var).dict(), type_=DAT.BASE_JOB)
elif isinstance(var, TaskInstance):
return cls._encode(TaskInstancePydantic.from_orm(var).dict(), type_=DAT.TASK_INSTANCE)
elif isinstance(var, DagRun):
return cls._encode(DagRunPydantic.from_orm(var).dict(), type_=DAT.DAG_RUN)
elif isinstance(var, Dataset):
return cls._encode(DatasetPydantic.from_orm(var).dict(), type_=DAT.DATA_SET)
else:
return cls.default_serialization(strict, var)
else:
return cls.default_serialization(strict, var)
@classmethod
def default_serialization(cls, strict, var) -> str:
log.debug("Cast type %s to str in serialization.", type(var))
if strict:
raise SerializationError("Encountered unexpected type")
return str(var)
@classmethod
def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
"""Helper function of depth first search for deserialization.
:meta private:
"""
# JSON primitives (except for dict) are not encoded.
if use_pydantic_models and not _ENABLE_AIP_44:
raise RuntimeError(
"Setting use_pydantic_models = True requires AIP-44 (in progress) feature flag to be true. "
"This parameter will be removed eventually when new serialization is used by AIP-44"
)
if cls._is_primitive(encoded_var):
return encoded_var
elif isinstance(encoded_var, list):
return [cls.deserialize(v, use_pydantic_models) for v in encoded_var]
if not isinstance(encoded_var, dict):
raise ValueError(f"The encoded_var should be dict and is {type(encoded_var)}")
var = encoded_var[Encoding.VAR]
type_ = encoded_var[Encoding.TYPE]
if type_ == DAT.DICT:
return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
elif type_ == DAT.DAG:
return SerializedDAG.deserialize_dag(var)
elif type_ == DAT.OP:
return SerializedBaseOperator.deserialize_operator(var)
elif type_ == DAT.DATETIME:
return pendulum.from_timestamp(var)
elif type_ == DAT.POD:
if not _has_kubernetes():
raise RuntimeError("Cannot deserialize POD objects without kubernetes libraries installed!")
pod = PodGenerator.deserialize_model_dict(var)
return pod
elif type_ == DAT.TIMEDELTA:
return datetime.timedelta(seconds=var)
elif type_ == DAT.TIMEZONE:
return decode_timezone(var)
elif type_ == DAT.RELATIVEDELTA:
return decode_relativedelta(var)
elif type_ == DAT.SET:
return {cls.deserialize(v, use_pydantic_models) for v in var}
elif type_ == DAT.TUPLE:
return tuple(cls.deserialize(v, use_pydantic_models) for v in var)
elif type_ == DAT.PARAM:
return cls._deserialize_param(var)
elif type_ == DAT.XCOM_REF:
return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG.
elif type_ == DAT.DATASET:
return Dataset(**var)
elif type_ == DAT.SIMPLE_TASK_INSTANCE:
return SimpleTaskInstance(**cls.deserialize(var))
elif use_pydantic_models and _ENABLE_AIP_44:
if type_ == DAT.BASE_JOB:
return JobPydantic.parse_obj(var)
elif type_ == DAT.TASK_INSTANCE:
return TaskInstancePydantic.parse_obj(var)
elif type_ == DAT.DAG_RUN:
return DagRunPydantic.parse_obj(var)
elif type_ == DAT.DATA_SET:
return DatasetPydantic.parse_obj(var)
else:
raise TypeError(f"Invalid type {type_!s} in deserialization.")
_deserialize_datetime = pendulum.from_timestamp
_deserialize_timezone = pendulum.tz.timezone
@classmethod
def _deserialize_timedelta(cls, seconds: int) -> datetime.timedelta:
return datetime.timedelta(seconds=seconds)
@classmethod
def _is_constructor_param(cls, attrname: str, instance: Any) -> bool:
return attrname in cls._CONSTRUCTOR_PARAMS
@classmethod
def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) -> bool:
"""
Return true if ``value`` is the hard-coded default for the given attribute.
This takes in to account cases where the ``max_active_tasks`` parameter is
stored in the ``_max_active_tasks`` attribute.
And by using `is` here only and not `==` this copes with the case a
user explicitly specifies an attribute with the same "value" as the
default. (This is because ``"default" is "default"`` will be False as
they are different strings with the same characters.)
Also returns True if the value is an empty list or empty dict. This is done
to account for the case where the default value of the field is None but has the
``field = field or {}`` set.
"""
if attrname in cls._CONSTRUCTOR_PARAMS and (
cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])
):
return True
return False
@classmethod
def _serialize_param(cls, param: Param):
return dict(
__class=f"{param.__module__}.{param.__class__.__name__}",
default=cls.serialize(param.value),
description=cls.serialize(param.description),
schema=cls.serialize(param.schema),
)
@classmethod
def _deserialize_param(cls, param_dict: dict):
"""
Workaround to serialize Param on older versions.
In 2.2.0, Param attrs were assumed to be json-serializable and were not run through
this class's ``serialize`` method. So before running through ``deserialize``,
we first verify that it's necessary to do.
"""
class_name = param_dict["__class"]
class_: type[Param] = import_string(class_name)
attrs = ("default", "description", "schema")
kwargs = {}
def is_serialized(val):
if isinstance(val, dict):
return Encoding.TYPE in val
if isinstance(val, list):
return all(isinstance(item, dict) and Encoding.TYPE in item for item in val)
return False
for attr in attrs:
if attr not in param_dict:
continue
val = param_dict[attr]
if is_serialized(val):
deserialized_val = cls.deserialize(param_dict[attr])
kwargs[attr] = deserialized_val
else:
kwargs[attr] = val
return class_(**kwargs)
@classmethod
def _serialize_params_dict(cls, params: ParamsDict | dict):
"""Serialize Params dict for a DAG or task."""
serialized_params = {}
for k, v in params.items():
# TODO: As of now, we would allow serialization of params which are of type Param only.
try:
class_identity = f"{v.__module__}.{v.__class__.__name__}"
except AttributeError:
class_identity = ""
if class_identity == "airflow.models.param.Param":
serialized_params[k] = cls._serialize_param(v)
else:
raise ValueError(
f"Params to a DAG or a Task can be only of type airflow.models.param.Param, "
f"but param {k!r} is {v.__class__}"
)
return serialized_params
@classmethod
def _deserialize_params_dict(cls, encoded_params: dict) -> ParamsDict:
"""Deserialize a DAG's Params dict."""
op_params = {}
for k, v in encoded_params.items():
if isinstance(v, dict) and "__class" in v:
op_params[k] = cls._deserialize_param(v)
else:
# Old style params, convert it
op_params[k] = Param(v)
return ParamsDict(op_params)
class DependencyDetector:
"""
Detects dependencies between DAGs.
:meta private:
"""
@staticmethod
def detect_task_dependencies(task: Operator) -> list[DagDependency]:
"""Detects dependencies caused by tasks."""
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.sensors.external_task import ExternalTaskSensor
deps = []
if isinstance(task, TriggerDagRunOperator):
deps.append(
DagDependency(
source=task.dag_id,
target=getattr(task, "trigger_dag_id"),
dependency_type="trigger",
dependency_id=task.task_id,
)
)
elif isinstance(task, ExternalTaskSensor):
deps.append(
DagDependency(
source=getattr(task, "external_dag_id"),
target=task.dag_id,
dependency_type="sensor",
dependency_id=task.task_id,
)
)
for obj in task.outlets or []:
if isinstance(obj, Dataset):
deps.append(
DagDependency(
source=task.dag_id,
target="dataset",
dependency_type="dataset",
dependency_id=obj.uri,
)
)
return deps
@staticmethod
def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]:
"""Detects dependencies set directly on the DAG object."""
if not dag:
return
for x in dag.dataset_triggers:
yield DagDependency(
source="dataset",
target=dag.dag_id,
dependency_type="dataset",
dependency_id=x.uri,
)
class SerializedBaseOperator(BaseOperator, BaseSerialization):
"""A JSON serializable representation of operator.
All operators are casted to SerializedBaseOperator after deserialization.
Class specific attributes used by UI are move to object attributes.
"""
_decorated_fields = {"executor_config"}
_CONSTRUCTOR_PARAMS = {
k: v.default
for k, v in signature(BaseOperator.__init__).parameters.items()
if v.default is not v.empty
}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# task_type is used by UI to display the correct class type, because UI only
# receives BaseOperator from deserialized DAGs.
self._task_type = "BaseOperator"
# Move class attributes into object attributes.
self.ui_color = BaseOperator.ui_color
self.ui_fgcolor = BaseOperator.ui_fgcolor
self.template_ext = BaseOperator.template_ext
self.template_fields = BaseOperator.template_fields
self.operator_extra_links = BaseOperator.operator_extra_links
@property
def task_type(self) -> str:
# Overwrites task_type of BaseOperator to use _task_type instead of
# __class__.__name__.
return self._task_type
@task_type.setter
def task_type(self, task_type: str):
self._task_type = task_type
@property
def operator_name(self) -> str:
# Overwrites operator_name of BaseOperator to use _operator_name instead of
# __class__.operator_name.
return self._operator_name
@operator_name.setter
def operator_name(self, operator_name: str):
self._operator_name = operator_name
@classmethod
def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]:
serialized_op = cls._serialize_node(op, include_deps=op.deps != MappedOperator.deps_for(BaseOperator))
# Handle expand_input and op_kwargs_expand_input.
expansion_kwargs = op._get_specified_expand_input()
if TYPE_CHECKING: # Let Mypy check the input type for us!
_ExpandInputRef.validate_expand_input_value(expansion_kwargs.value)
serialized_op[op._expand_input_attr] = {
"type": get_map_type_key(expansion_kwargs),
"value": cls.serialize(expansion_kwargs.value),
}
# Simplify partial_kwargs by comparing it to the most barebone object.
# Remove all entries that are simply default values.
serialized_partial = serialized_op["partial_kwargs"]
for k, default in _get_default_mapped_partial().items():
try:
v = serialized_partial[k]
except KeyError:
continue
if v == default:
del serialized_partial[k]
serialized_op["_is_mapped"] = True
return serialized_op
@classmethod
def serialize_operator(cls, op: BaseOperator) -> dict[str, Any]:
return cls._serialize_node(op, include_deps=op.deps is not BaseOperator.deps)
@classmethod
def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) -> dict[str, Any]:
"""Serializes operator into a JSON object."""
serialize_op = cls.serialize_to_json(op, cls._decorated_fields)
serialize_op["_task_type"] = getattr(op, "_task_type", type(op).__name__)
serialize_op["_task_module"] = getattr(op, "_task_module", type(op).__module__)
if op.operator_name != serialize_op["_task_type"]:
serialize_op["_operator_name"] = op.operator_name
# Used to determine if an Operator is inherited from EmptyOperator
serialize_op["_is_empty"] = op.inherits_from_empty_operator
if op.operator_extra_links:
serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links(
op.operator_extra_links
)
if include_deps:
serialize_op["deps"] = cls._serialize_deps(op.deps)
# Store all template_fields as they are if there are JSON Serializable
# If not, store them as strings
# And raise an exception if the field is not templateable
forbidden_fields = set(inspect.signature(BaseOperator.__init__).parameters.keys())
if op.template_fields:
for template_field in op.template_fields:
if template_field in forbidden_fields:
raise AirflowException(f"Cannot template BaseOperator fields: {template_field}")
value = getattr(op, template_field, None)
if not cls._is_excluded(value, template_field, op):
serialize_op[template_field] = serialize_template_field(value)
if op.params:
serialize_op["params"] = cls._serialize_params_dict(op.params)
return serialize_op
@classmethod
def _serialize_deps(cls, op_deps: Iterable[BaseTIDep]) -> list[str]:
from airflow import plugins_manager
plugins_manager.initialize_ti_deps_plugins()
if plugins_manager.registered_ti_dep_classes is None:
raise AirflowException("Can not load plugins")
deps = []
for dep in op_deps:
klass = type(dep)
module_name = klass.__module__
qualname = f"{module_name}.{klass.__name__}"
if (
not qualname.startswith("airflow.ti_deps.deps.")
and qualname not in plugins_manager.registered_ti_dep_classes
):
raise SerializationError(
f"Custom dep class {qualname} not serialized, please register it through plugins."
)
deps.append(qualname)
# deps needs to be sorted here, because op_deps is a set, which is unstable when traversing,
# and the same call may get different results.
# When calling json.dumps(self.data, sort_keys=True) to generate dag_hash, misjudgment will occur
return sorted(deps)
@classmethod
def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
if "label" not in encoded_op:
# Handle deserialization of old data before the introduction of TaskGroup
encoded_op["label"] = encoded_op["task_id"]
# Extra Operator Links defined in Plugins
op_extra_links_from_plugin = {}
if "_operator_name" not in encoded_op:
encoded_op["_operator_name"] = encoded_op["_task_type"]
# We don't want to load Extra Operator links in Scheduler
if cls._load_operator_extra_links:
from airflow import plugins_manager
plugins_manager.initialize_extra_operators_links_plugins()
if plugins_manager.operator_extra_links is None:
raise AirflowException("Can not load plugins")
for ope in plugins_manager.operator_extra_links:
for operator in ope.operators:
if (
operator.__name__ == encoded_op["_task_type"]
and operator.__module__ == encoded_op["_task_module"]
):
op_extra_links_from_plugin.update({ope.name: ope})
# If OperatorLinks are defined in Plugins but not in the Operator that is being Serialized
# set the Operator links attribute
# The case for "If OperatorLinks are defined in the operator that is being Serialized"
# is handled in the deserialization loop where it matches k == "_operator_extra_links"
if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))
for k, v in encoded_op.items():
# Todo: TODO: Remove in Airflow 3.0 when dummy operator is removed
if k == "_is_dummy":
k = "_is_empty"
if k in ("_outlets", "_inlets"):
# `_outlets` -> `outlets`
k = k[1:]
if k == "_downstream_task_ids":
# Upgrade from old format/name
k = "downstream_task_ids"
if k == "label":
# Label shouldn't be set anymore -- it's computed from task_id now
continue
elif k == "downstream_task_ids":
v = set(v)
elif k == "subdag":
v = SerializedDAG.deserialize_dag(v)
elif k in {"retry_delay", "execution_timeout", "sla", "max_retry_delay"}:
v = cls._deserialize_timedelta(v)
elif k in encoded_op["template_fields"]:
pass
elif k == "resources":
v = Resources.from_dict(v)
elif k.endswith("_date"):
v = cls._deserialize_datetime(v)
elif k == "_operator_extra_links":
if cls._load_operator_extra_links:
op_predefined_extra_links = cls._deserialize_operator_extra_links(v)
# If OperatorLinks with the same name exists, Links via Plugin have higher precedence
op_predefined_extra_links.update(op_extra_links_from_plugin)
else:
op_predefined_extra_links = {}
v = list(op_predefined_extra_links.values())
k = "operator_extra_links"
elif k == "deps":
v = cls._deserialize_deps(v)
elif k == "params":
v = cls._deserialize_params_dict(v)
if op.params: # Merge existing params if needed.
v, new = op.params, v
v.update(new)
elif k == "partial_kwargs":
v = {arg: cls.deserialize(value) for arg, value in v.items()}
elif k in {"expand_input", "op_kwargs_expand_input"}:
v = _ExpandInputRef(v["type"], cls.deserialize(v["value"]))
elif k in cls._decorated_fields or k not in op.get_serialized_fields():
v = cls.deserialize(v)
elif k in ("outlets", "inlets"):
v = cls.deserialize(v)
# else use v as it is
setattr(op, k, v)
for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys():
# TODO: refactor deserialization of BaseOperator and MappedOperaotr (split it out), then check
# could go away.
if not hasattr(op, k):
setattr(op, k, None)
# Set all the template_field to None that were not present in Serialized JSON
for field in op.template_fields:
if not hasattr(op, field):
setattr(op, field, None)
# Used to determine if an Operator is inherited from EmptyOperator
setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))
@classmethod
def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
"""Deserializes an operator from a JSON object."""
op: Operator
if encoded_op.get("_is_mapped", False):
# Most of these will be loaded later, these are just some stand-ins.
op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()}
try:
operator_name = encoded_op["_operator_name"]
except KeyError:
operator_name = encoded_op["_task_type"]
op = MappedOperator(
operator_class=op_data,
expand_input=EXPAND_INPUT_EMPTY,
partial_kwargs={},
task_id=encoded_op["task_id"],
params={},
deps=MappedOperator.deps_for(BaseOperator),
operator_extra_links=BaseOperator.operator_extra_links,
template_ext=BaseOperator.template_ext,
template_fields=BaseOperator.template_fields,
template_fields_renderers=BaseOperator.template_fields_renderers,
ui_color=BaseOperator.ui_color,
ui_fgcolor=BaseOperator.ui_fgcolor,
is_empty=False,
task_module=encoded_op["_task_module"],
task_type=encoded_op["_task_type"],
operator_name=operator_name,
dag=None,
task_group=None,
start_date=None,
end_date=None,
disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
expand_input_attr=encoded_op["_expand_input_attr"],
)
else:
op = SerializedBaseOperator(task_id=encoded_op["task_id"])
cls.populate_operator(op, encoded_op)
return op
@classmethod
def detect_dependencies(cls, op: Operator) -> set[DagDependency]:
"""Detects between DAG dependencies for the operator."""
def get_custom_dep() -> list[DagDependency]:
"""
If custom dependency detector is configured, use it.
TODO: Remove this logic in 3.0.
"""
custom_dependency_detector_cls = conf.getimport("scheduler", "dependency_detector", fallback=None)
if not (
custom_dependency_detector_cls is None or custom_dependency_detector_cls is DependencyDetector
):
warnings.warn(
"Use of a custom dependency detector is deprecated. "
"Support will be removed in a future release.",
RemovedInAirflow3Warning,
)
dep = custom_dependency_detector_cls().detect_task_dependencies(op)
if type(dep) is DagDependency:
return [dep]
return []
dependency_detector = DependencyDetector()
deps = set(dependency_detector.detect_task_dependencies(op))
deps.update(get_custom_dep()) # todo: remove in 3.0
return deps
@classmethod
def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
if var is not None and op.has_dag() and attrname.endswith("_date"):
# If this date is the same as the matching field in the dag, then
# don't store it again at the task level.
dag_date = getattr(op.dag, attrname, None)
if var is dag_date or var == dag_date:
return True
return super()._is_excluded(var, attrname, op)
@classmethod
def _deserialize_deps(cls, deps: list[str]) -> set[BaseTIDep]:
from airflow import plugins_manager
plugins_manager.initialize_ti_deps_plugins()
if plugins_manager.registered_ti_dep_classes is None:
raise AirflowException("Can not load plugins")
instances = set()
for qn in set(deps):
if (
not qn.startswith("airflow.ti_deps.deps.")
and qn not in plugins_manager.registered_ti_dep_classes
):
raise SerializationError(
f"Custom dep class {qn} not deserialized, please register it through plugins."
)
try:
instances.add(import_string(qn)())
except ImportError:
log.warning("Error importing dep %r", qn, exc_info=True)
return instances
@classmethod
def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> dict[str, BaseOperatorLink]:
"""
Deserialize Operator Links if the Classes are registered in Airflow Plugins.
Error is raised if the OperatorLink is not found in Plugins too.
:param encoded_op_links: Serialized Operator Link
:return: De-Serialized Operator Link
"""
from airflow import plugins_manager
plugins_manager.initialize_extra_operators_links_plugins()
if plugins_manager.registered_operator_link_classes is None:
raise AirflowException("Can't load plugins")
op_predefined_extra_links = {}
for _operator_links_source in encoded_op_links:
# Get the key, value pair as Tuple where key is OperatorLink ClassName
# and value is the dictionary containing the arguments passed to the OperatorLink
#
# Example of a single iteration:
#
# _operator_links_source =
# {
# 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {
# 'index': 0
# }
# },
#
# list(_operator_links_source.items()) =
# [
# (
# 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
# {'index': 0}
# )
# ]
#
# list(_operator_links_source.items())[0] =
# (
# 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
# {
# 'index': 0
# }
# )
_operator_link_class_path, data = list(_operator_links_source.items())[0]
if _operator_link_class_path in get_operator_extra_links():
single_op_link_class = import_string(_operator_link_class_path)
elif _operator_link_class_path in plugins_manager.registered_operator_link_classes:
single_op_link_class = plugins_manager.registered_operator_link_classes[
_operator_link_class_path
]
else:
log.error("Operator Link class %r not registered", _operator_link_class_path)
return {}
op_predefined_extra_link: BaseOperatorLink = cattr.structure(data, single_op_link_class)
op_predefined_extra_links.update({op_predefined_extra_link.name: op_predefined_extra_link})
return op_predefined_extra_links
@classmethod
def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOperatorLink]):
"""
Serialize Operator Links.
Store the import path of the OperatorLink and the arguments passed to it.
For example:
``[{'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}]``
:param operator_extra_links: Operator Link
:return: Serialized Operator Link
"""
serialize_operator_extra_links = []
for operator_extra_link in operator_extra_links:
op_link_arguments = cattr.unstructure(operator_extra_link)
if not isinstance(op_link_arguments, dict):
op_link_arguments = {}
module_path = (
f"{operator_extra_link.__class__.__module__}.{operator_extra_link.__class__.__name__}"
)
serialize_operator_extra_links.append({module_path: op_link_arguments})
return serialize_operator_extra_links
@classmethod
def serialize(cls, var: Any, *, strict: bool = False, use_pydantic_models: bool = False) -> Any:
# the wonders of multiple inheritance BaseOperator defines an instance method
return BaseSerialization.serialize(var=var, strict=strict, use_pydantic_models=use_pydantic_models)
@classmethod
def deserialize(cls, encoded_var: Any, use_pydantic_models: bool = False) -> Any:
return BaseSerialization.deserialize(encoded_var=encoded_var, use_pydantic_models=use_pydantic_models)
class SerializedDAG(DAG, BaseSerialization):
"""
A JSON serializable representation of DAG.
A stringified DAG can only be used in the scope of scheduler and webserver, because fields
that are not serializable, such as functions and customer defined classes, are casted to
strings.
Compared with SimpleDAG: SerializedDAG contains all information for webserver.
Compared with DagPickle: DagPickle contains all information for worker, but some DAGs are
not pickle-able. SerializedDAG works for all DAGs.
"""
_decorated_fields = {"schedule_interval", "default_args", "_access_control"}
@staticmethod
def __get_constructor_defaults():
param_to_attr = {
"max_active_tasks": "_max_active_tasks",
"description": "_description",
"default_view": "_default_view",
"access_control": "_access_control",
}
return {
param_to_attr.get(k, k): v.default
for k, v in signature(DAG.__init__).parameters.items()
if v.default is not v.empty
}
_CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__() # type: ignore
del __get_constructor_defaults
_json_schema = lazy_object_proxy.Proxy(load_dag_schema)
@classmethod
def serialize_dag(cls, dag: DAG) -> dict:
"""Serializes a DAG into a JSON object."""
try:
serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields)
serialized_dag["_processor_dags_folder"] = DAGS_FOLDER
# If schedule_interval is backed by timetable, serialize only
# timetable; vice versa for a timetable backed by schedule_interval.
if dag.timetable.summary == dag.schedule_interval:
del serialized_dag["schedule_interval"]
else:
del serialized_dag["timetable"]
serialized_dag["tasks"] = [cls.serialize(task) for _, task in dag.task_dict.items()]
dag_deps = {
dep
for task in dag.task_dict.values()
for dep in SerializedBaseOperator.detect_dependencies(task)
}
dag_deps.update(DependencyDetector.detect_dag_dependencies(dag))
serialized_dag["dag_dependencies"] = [x.__dict__ for x in dag_deps]
serialized_dag["_task_group"] = TaskGroupSerialization.serialize_task_group(dag.task_group)
# Edge info in the JSON exactly matches our internal structure
serialized_dag["edge_info"] = dag.edge_info
serialized_dag["params"] = cls._serialize_params_dict(dag.params)
# has_on_*_callback are only stored if the value is True, as the default is False
if dag.has_on_success_callback:
serialized_dag["has_on_success_callback"] = True
if dag.has_on_failure_callback:
serialized_dag["has_on_failure_callback"] = True
return serialized_dag
except SerializationError:
raise
except Exception as e:
raise SerializationError(f"Failed to serialize DAG {dag.dag_id!r}: {e}")
@classmethod
def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG:
"""Deserializes a DAG from a JSON object."""
dag = SerializedDAG(dag_id=encoded_dag["_dag_id"])
for k, v in encoded_dag.items():
if k == "_downstream_task_ids":
v = set(v)
elif k == "tasks":
SerializedBaseOperator._load_operator_extra_links = cls._load_operator_extra_links
v = {task["task_id"]: SerializedBaseOperator.deserialize_operator(task) for task in v}
k = "task_dict"
elif k == "timezone":
v = cls._deserialize_timezone(v)
elif k == "dagrun_timeout":
v = cls._deserialize_timedelta(v)
elif k.endswith("_date"):
v = cls._deserialize_datetime(v)
elif k == "edge_info":
# Value structure matches exactly
pass
elif k == "timetable":
v = _decode_timetable(v)
elif k in cls._decorated_fields:
v = cls.deserialize(v)
elif k == "params":
v = cls._deserialize_params_dict(v)
elif k == "dataset_triggers":
v = cls.deserialize(v)
# else use v as it is
setattr(dag, k, v)
# A DAG is always serialized with only one of schedule_interval and
# timetable. This back-populates the other to ensure the two attributes
# line up correctly on the DAG instance.
if "timetable" in encoded_dag:
dag.schedule_interval = dag.timetable.summary
else:
dag.timetable = create_timetable(dag.schedule_interval, dag.timezone)
# Set _task_group
if "_task_group" in encoded_dag:
dag._task_group = TaskGroupSerialization.deserialize_task_group(
encoded_dag["_task_group"],
None,
dag.task_dict,
dag,
)
else:
# This must be old data that had no task_group. Create a root TaskGroup and add
# all tasks to it.
dag._task_group = TaskGroup.create_root(dag)
for task in dag.tasks:
dag.task_group.add(task)
# Set has_on_*_callbacks to True if they exist in Serialized blob as False is the default
if "has_on_success_callback" in encoded_dag:
dag.has_on_success_callback = True
if "has_on_failure_callback" in encoded_dag:
dag.has_on_failure_callback = True
keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys()
for k in keys_to_set_none:
setattr(dag, k, None)
for task in dag.task_dict.values():
task.dag = dag
for date_attr in ["start_date", "end_date"]:
if getattr(task, date_attr) is None:
setattr(task, date_attr, getattr(dag, date_attr))
if task.subdag is not None:
setattr(task.subdag, "parent_dag", dag)
# Dereference expand_input and op_kwargs_expand_input.
for k in ("expand_input", "op_kwargs_expand_input"):
kwargs_ref = getattr(task, k, None)
if isinstance(kwargs_ref, _ExpandInputRef):
setattr(task, k, kwargs_ref.deref(dag))
for task_id in task.downstream_task_ids:
# Bypass set_upstream etc here - it does more than we want
dag.task_dict[task_id].upstream_task_ids.add(task.task_id)
return dag
@classmethod
def to_dict(cls, var: Any) -> dict:
"""Stringifies DAGs and operators contained by var and returns a dict of var."""
json_dict = {"__version": cls.SERIALIZER_VERSION, "dag": cls.serialize_dag(var)}
# Validate Serialized DAG with Json Schema. Raises Error if it mismatches
cls.validate_schema(json_dict)
return json_dict
@classmethod
def from_dict(cls, serialized_obj: dict) -> SerializedDAG:
"""Deserializes a python dict in to the DAG and operators it contains."""
ver = serialized_obj.get("__version", "<not present>")
if ver != cls.SERIALIZER_VERSION:
raise ValueError(f"Unsure how to deserialize version {ver!r}")
return cls.deserialize_dag(serialized_obj["dag"])
class TaskGroupSerialization(BaseSerialization):
"""JSON serializable representation of a task group."""
@classmethod
def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None:
"""Serializes TaskGroup into a JSON object."""
if not task_group:
return None
# task_group.xxx_ids needs to be sorted here, because task_group.xxx_ids is a set,
# when converting set to list, the order is uncertain.
# When calling json.dumps(self.data, sort_keys=True) to generate dag_hash, misjudgment will occur
encoded = {
"_group_id": task_group._group_id,
"prefix_group_id": task_group.prefix_group_id,
"tooltip": task_group.tooltip,
"ui_color": task_group.ui_color,
"ui_fgcolor": task_group.ui_fgcolor,
"children": {
label: child.serialize_for_task_group() for label, child in task_group.children.items()
},
"upstream_group_ids": cls.serialize(sorted(task_group.upstream_group_ids)),
"downstream_group_ids": cls.serialize(sorted(task_group.downstream_group_ids)),
"upstream_task_ids": cls.serialize(sorted(task_group.upstream_task_ids)),
"downstream_task_ids": cls.serialize(sorted(task_group.downstream_task_ids)),
}
if isinstance(task_group, MappedTaskGroup):
expand_input = task_group._expand_input
encoded["expand_input"] = {
"type": get_map_type_key(expand_input),
"value": cls.serialize(expand_input.value),
}
encoded["is_mapped"] = True
return encoded
@classmethod
def deserialize_task_group(
cls,
encoded_group: dict[str, Any],
parent_group: TaskGroup | None,
task_dict: dict[str, Operator],
dag: SerializedDAG,
) -> TaskGroup:
"""Deserializes a TaskGroup from a JSON object."""
group_id = cls.deserialize(encoded_group["_group_id"])
kwargs = {
key: cls.deserialize(encoded_group[key])
for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
}
if not encoded_group.get("is_mapped"):
group = TaskGroup(group_id=group_id, parent_group=parent_group, dag=dag, **kwargs)
else:
xi = encoded_group["expand_input"]
group = MappedTaskGroup(
group_id=group_id,
parent_group=parent_group,
dag=dag,
expand_input=_ExpandInputRef(xi["type"], cls.deserialize(xi["value"])).deref(dag),
**kwargs,
)
def set_ref(task: Operator) -> Operator:
task.task_group = weakref.proxy(group)
return task
group.children = {
label: set_ref(task_dict[val])
if _type == DAT.OP
else cls.deserialize_task_group(val, group, task_dict, dag=dag)
for label, (_type, val) in encoded_group["children"].items()
}
group.upstream_group_ids.update(cls.deserialize(encoded_group["upstream_group_ids"]))
group.downstream_group_ids.update(cls.deserialize(encoded_group["downstream_group_ids"]))
group.upstream_task_ids.update(cls.deserialize(encoded_group["upstream_task_ids"]))
group.downstream_task_ids.update(cls.deserialize(encoded_group["downstream_task_ids"]))
return group
@dataclass(frozen=True)
class DagDependency:
"""Dataclass for representing dependencies between DAGs.
These are calculated during serialization and attached to serialized DAGs.
"""
source: str
target: str
dependency_type: str
dependency_id: str | None = None
@property
def node_id(self):
"""Node ID for graph rendering."""
val = f"{self.dependency_type}"
if not self.dependency_type == "dataset":
val += f":{self.source}:{self.target}"
if self.dependency_id:
val += f":{self.dependency_id}"
return val
def _has_kubernetes() -> bool:
global HAS_KUBERNETES
if "HAS_KUBERNETES" in globals():
return HAS_KUBERNETES
# Loading kube modules is expensive, so delay it until the last moment
try:
from kubernetes.client import models as k8s
from airflow.kubernetes.pod_generator import PodGenerator
globals()["k8s"] = k8s
globals()["PodGenerator"] = PodGenerator
# isort: on
HAS_KUBERNETES = True
except ImportError:
HAS_KUBERNETES = False
return HAS_KUBERNETES