| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| import collections |
| import collections.abc |
| import copy |
| import functools |
| import itertools |
| import logging |
| import os |
| import pathlib |
| import pickle |
| import re |
| import sys |
| import traceback |
| import warnings |
| import weakref |
| from collections import deque |
| from datetime import datetime, timedelta |
| from inspect import signature |
| from typing import ( |
| TYPE_CHECKING, |
| Any, |
| Callable, |
| Collection, |
| Deque, |
| Iterable, |
| Iterator, |
| List, |
| Sequence, |
| Union, |
| cast, |
| overload, |
| ) |
| from urllib.parse import urlsplit |
| |
| import jinja2 |
| import pendulum |
| from dateutil.relativedelta import relativedelta |
| from pendulum.tz.timezone import Timezone |
| from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, and_, case, func, not_, or_ |
| from sqlalchemy.ext.associationproxy import association_proxy |
| from sqlalchemy.orm import backref, joinedload, relationship |
| from sqlalchemy.orm.query import Query |
| from sqlalchemy.orm.session import Session |
| from sqlalchemy.sql import expression |
| |
| import airflow.templates |
| from airflow import settings, utils |
| from airflow.api_internal.internal_api_call import internal_api_call |
| from airflow.compat.functools import cached_property |
| from airflow.configuration import conf, secrets_backend_list |
| from airflow.exceptions import ( |
| AirflowDagInconsistent, |
| AirflowException, |
| AirflowSkipException, |
| DagInvalidTriggerRule, |
| DuplicateTaskIdFound, |
| RemovedInAirflow3Warning, |
| TaskNotFound, |
| ) |
| from airflow.jobs.job import run_job |
| from airflow.models.abstractoperator import AbstractOperator |
| from airflow.models.base import Base, StringID |
| from airflow.models.baseoperator import BaseOperator |
| from airflow.models.dagcode import DagCode |
| from airflow.models.dagpickle import DagPickle |
| from airflow.models.dagrun import DagRun |
| from airflow.models.operator import Operator |
| from airflow.models.param import DagParam, ParamsDict |
| from airflow.models.taskinstance import Context, TaskInstance, TaskInstanceKey, clear_task_instances |
| from airflow.secrets.local_filesystem import LocalFilesystemBackend |
| from airflow.security import permissions |
| from airflow.stats import Stats |
| from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable |
| from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable |
| from airflow.timetables.simple import ( |
| ContinuousTimetable, |
| DatasetTriggeredTimetable, |
| NullTimetable, |
| OnceTimetable, |
| ) |
| from airflow.typing_compat import Literal |
| from airflow.utils import timezone |
| from airflow.utils.dag_cycle_tester import check_cycle |
| from airflow.utils.dates import cron_presets, date_range as utils_date_range |
| from airflow.utils.decorators import fixup_decorator_warning_stack |
| from airflow.utils.helpers import at_most_one, exactly_one, validate_key |
| from airflow.utils.log.logging_mixin import LoggingMixin |
| from airflow.utils.session import NEW_SESSION, provide_session |
| from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks |
| from airflow.utils.state import DagRunState, State, TaskInstanceState |
| from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType |
| |
| if TYPE_CHECKING: |
| from types import ModuleType |
| |
| from airflow.datasets import Dataset |
| from airflow.decorators import TaskDecoratorCollection |
| from airflow.models.dagbag import DagBag |
| from airflow.models.slamiss import SlaMiss |
| from airflow.utils.task_group import TaskGroup |
| |
log = logging.getLogger(__name__)

# Accepted values for the DAG ``default_view`` argument; ``DAG.__init__`` rejects
# anything else (apart from the deprecated alias "tree", which it maps to "grid").
DEFAULT_VIEW_PRESETS = ["grid", "graph", "duration", "gantt", "landing_times"]
# Accepted values for the DAG ``orientation`` argument (layout direction of the graph view).
ORIENTATION_PRESETS = ["LR", "TB", "RL", "BT"]

# Maximum number of characters allowed in a single DAG tag (enforced in ``DAG.__init__``).
TAG_MAX_LEN = 100

# Type of the DAG-level ``on_success_callback`` / ``on_failure_callback`` callables.
DagStateChangeCallback = Callable[[Context], None]
# Value types accepted by the legacy ``schedule_interval`` argument.
ScheduleInterval = Union[None, str, timedelta, relativedelta]

# FIXME: Ideally this should be Union[Literal[NOTSET], ScheduleInterval],
# but Mypy cannot handle that right now. Track progress of PEP 661 for progress.
# See also: https://discuss.python.org/t/9126/7
ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval]
# Everything the ``schedule`` argument accepts: a legacy interval value, a
# Timetable instance, or a collection of Datasets.
ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, Collection["Dataset"]]

# Type of the DAG ``sla_miss_callback`` callables.
SLAMissCallback = Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None]

# Backward compatibility: If neither schedule_interval nor timetable is
# *provided by the user*, default to a one-day interval.
DEFAULT_SCHEDULE_INTERVAL = timedelta(days=1)
| |
| |
class InconsistentDataInterval(AirflowException):
    """Raised when exactly one of a model's two data-interval fields is ``None``.

    Runs scheduled before AIP-39 store ``None`` in both data-interval fields;
    runs scheduled afterwards store a datetime in both. A mix of the two means
    the stored row is inconsistent.
    """

    _template = (
        "Inconsistent {cls}: {start[0]}={start[1]!r}, {end[0]}={end[1]!r}, "
        "they must be either both None or both datetime"
    )

    def __init__(self, instance: Any, start_field_name: str, end_field_name: str) -> None:
        # Capture the offending values at construction time so the message
        # reports exactly what was observed.
        self._class_name = type(instance).__name__
        start_value = getattr(instance, start_field_name)
        end_value = getattr(instance, end_field_name)
        self._start_field = (start_field_name, start_value)
        self._end_field = (end_field_name, end_value)

    def __str__(self) -> str:
        substitutions = {
            "cls": self._class_name,
            "start": self._start_field,
            "end": self._end_field,
        }
        return self._template.format(**substitutions)
| |
| |
def _get_model_data_interval(
    instance: Any,
    start_field_name: str,
    end_field_name: str,
) -> DataInterval | None:
    """Build a data interval from two datetime attributes stored on *instance*.

    Both attribute values are passed through ``timezone.coerce_datetime``.

    :param instance: Object carrying the interval start/end attributes.
    :param start_field_name: Name of the attribute holding the interval start.
    :param end_field_name: Name of the attribute holding the interval end.
    :return: ``None`` when both attributes are ``None``; otherwise a ``DataInterval``.
    :raises InconsistentDataInterval: when exactly one of the two attributes is ``None``.
    """
    start = timezone.coerce_datetime(getattr(instance, start_field_name))
    end = timezone.coerce_datetime(getattr(instance, end_field_name))
    start_missing = start is None
    end_missing = end is None
    if start_missing != end_missing:
        # Exactly one side is set: the row is half pre-AIP-39, half post.
        raise InconsistentDataInterval(instance, start_field_name, end_field_name)
    if start_missing:
        return None
    return DataInterval(start, end)
| |
| |
def create_timetable(interval: ScheduleIntervalArg, timezone: Timezone) -> Timetable:
    """Translate a legacy ``schedule_interval`` value into a Timetable.

    :param interval: ``NOTSET`` (use the one-day default), ``None`` (never
        scheduled), the presets ``"@once"`` / ``"@continuous"``, a
        ``timedelta``/``relativedelta``, or any other string (treated as a cron
        expression).
    :param timezone: Timezone handed to the cron-based timetable.
    :return: The matching Timetable instance.
    :raises ValueError: if *interval* is none of the above.
    """
    if interval is NOTSET:
        return DeltaDataIntervalTimetable(DEFAULT_SCHEDULE_INTERVAL)
    if interval is None:
        return NullTimetable()
    # The presets must be checked before the generic string (cron) case.
    for preset, timetable_cls in (("@once", OnceTimetable), ("@continuous", ContinuousTimetable)):
        if interval == preset:
            return timetable_cls()
    if isinstance(interval, (timedelta, relativedelta)):
        return DeltaDataIntervalTimetable(interval)
    if isinstance(interval, str):
        return CronDataIntervalTimetable(interval, timezone)
    raise ValueError(f"{interval!r} is not a valid schedule_interval.")
| |
| |
def get_last_dagrun(dag_id, session, include_externally_triggered=False):
    """
    Return the most recent dag run of a DAG, or None if it has never run.

    "Most recent" means the run with the latest ``execution_date``. Runs of any
    type (e.g. scheduled or backfilled) are considered; the only optional
    filter is on externally triggered runs.

    :param dag_id: ID of the DAG whose runs are queried.
    :param session: ORM session used to run the query.
    :param include_externally_triggered: When False (the default), runs whose
        ``external_trigger`` flag is set are excluded.
    :return: The latest matching DagRun, or None if no run matches.
    """
    query = session.query(DagRun).filter(DagRun.dag_id == dag_id)
    if not include_externally_triggered:
        query = query.filter(DagRun.external_trigger == expression.false())
    return query.order_by(DagRun.execution_date.desc()).first()
| |
| |
def get_dataset_triggered_next_run_info(
    dag_ids: list[str], *, session: Session
) -> dict[str, dict[str, int | str]]:
    """
    Summarize how close each dataset-triggered DAG is to its next run.

    For each DAG in ``dag_ids`` that has dataset schedule references, count how
    many of its datasets already have a queued event for it, i.e. the data
    behind a message such as "1 of 2 datasets updated".

    :param dag_ids: IDs of the DAGs to inspect. DAGs without any
        ``DagScheduleDatasetReference`` rows do not appear in the result.
    :param session: ORM session used to run the query.
    :return: Mapping of dag_id to a dict with the keys

        * ``"uri"`` -- the dataset URI when the DAG's group has exactly one row,
          otherwise an empty string;
        * ``"ready"`` -- number of rows with a matching ``DatasetDagRunQueue``
          entry targeting that DAG;
        * ``"total"`` -- number of joined rows for the DAG (presumably one per
          scheduling dataset, assuming at most one queue row per dataset/DAG
          pair -- TODO confirm).
    """
    from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ, DatasetModel

    return {
        x.dag_id: {
            "uri": x.uri,
            "ready": x.ready,
            "total": x.total,
        }
        for x in session.query(
            DagScheduleDatasetReference.dag_id,
            # This is a dirty hack to workaround group by requiring an aggregate, since grouping by dataset
            # is not what we want to do here...but it works
            case((func.count() == 1, func.max(DatasetModel.uri)), else_="").label("uri"),
            func.count().label("total"),
            # Rows that found a queue entry in the outer join below count as "ready".
            func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)).label("ready"),
        )
        # Outer join: datasets without a queued event still count towards "total".
        .join(
            DDRQ,
            and_(
                DDRQ.dataset_id == DagScheduleDatasetReference.dataset_id,
                DDRQ.target_dag_id == DagScheduleDatasetReference.dag_id,
            ),
            isouter=True,
        )
        .join(
            DatasetModel,
            DatasetModel.id == DagScheduleDatasetReference.dataset_id,
        )
        .group_by(
            DagScheduleDatasetReference.dag_id,
        )
        .filter(DagScheduleDatasetReference.dag_id.in_(dag_ids))
        .all()
    }
| |
| |
| @functools.total_ordering |
| class DAG(LoggingMixin): |
| """ |
| A dag (directed acyclic graph) is a collection of tasks with directional |
| dependencies. A dag also has a schedule, a start date and an end date |
| (optional). For each schedule, (say daily or hourly), the DAG needs to run |
| each individual tasks as their dependencies are met. Certain tasks have |
| the property of depending on their own past, meaning that they can't run |
| until their previous schedule (and upstream tasks) are completed. |
| |
| DAGs essentially act as namespaces for tasks. A task_id can only be |
| added once to a DAG. |
| |
| Note that if you plan to use time zones all the dates provided should be pendulum |
| dates. See :ref:`timezone_aware_dags`. |
| |
| .. versionadded:: 2.4 |
| The *schedule* argument to specify either time-based scheduling logic |
| (timetable), or dataset-driven triggers. |
| |
| .. deprecated:: 2.4 |
| The arguments *schedule_interval* and *timetable*. Their functionalities |
| are merged into the new *schedule* argument. |
| |
| :param dag_id: The id of the DAG; must consist exclusively of alphanumeric |
| characters, dashes, dots and underscores (all ASCII) |
| :param description: The description for the DAG to e.g. be shown on the webserver |
| :param schedule: Defines the rules according to which DAG runs are scheduled. Can |
| accept cron string, timedelta object, Timetable, or list of Dataset objects. |
| See also :doc:`/howto/timetable`. |
| :param start_date: The timestamp from which the scheduler will |
| attempt to backfill |
| :param end_date: A date beyond which your DAG won't run, leave to None |
| for open-ended scheduling |
| :param template_searchpath: This list of folders (non-relative) |
| defines where jinja will look for your templates. Order matters. |
| Note that jinja/airflow includes the path of your DAG file by |
| default |
| :param template_undefined: Template undefined type. |
| :param user_defined_macros: a dictionary of macros that will be exposed |
| in your jinja templates. For example, passing ``dict(foo='bar')`` |
| to this argument allows you to ``{{ foo }}`` in all jinja |
| templates related to this DAG. Note that you can pass any |
| type of object here. |
| :param user_defined_filters: a dictionary of filters that will be exposed |
| in your jinja templates. For example, passing |
| ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows |
| you to ``{{ 'world' | hello }}`` in all jinja templates related to |
| this DAG. |
| :param default_args: A dictionary of default parameters to be used |
| as constructor keyword parameters when initialising operators. |
| Note that operators have the same hook, and precede those defined |
| here, meaning that if your dict contains `'depends_on_past': True` |
| here and `'depends_on_past': False` in the operator's call |
| `default_args`, the actual value will be `False`. |
| :param params: a dictionary of DAG level parameters that are made |
| accessible in templates, namespaced under `params`. These |
| params can be overridden at the task level. |
| :param max_active_tasks: the number of task instances allowed to run |
| concurrently |
| :param max_active_runs: maximum number of active DAG runs, beyond this |
| number of DAG runs in a running state, the scheduler won't create |
| new active DAG runs |
| :param dagrun_timeout: specify how long a DagRun should be up before |
| timing out / failing, so that new DagRuns can be created. |
| :param sla_miss_callback: specify a function or list of functions to call when reporting SLA |
| timeouts. See :ref:`sla_miss_callback<concepts:sla_miss_callback>` for |
| more information about the function signature and parameters that are |
| passed to the callback. |
| :param default_view: Specify DAG default view (grid, graph, duration, |
| gantt, landing_times), default grid |
| :param orientation: Specify DAG orientation in graph view (LR, TB, RL, BT), default LR |
| :param catchup: Perform scheduler catchup (or only run latest)? Defaults to True |
| :param on_failure_callback: A function or list of functions to be called when a DagRun of this dag fails. |
| A context dictionary is passed as a single parameter to this function. |
| :param on_success_callback: Much like the ``on_failure_callback`` except |
| that it is executed when the dag succeeds. |
| :param access_control: Specify optional DAG-level actions, e.g., |
| "{'role1': {'can_read'}, 'role2': {'can_read', 'can_edit', 'can_delete'}}" |
| :param is_paused_upon_creation: Specifies if the dag is paused when created for the first time. |
| If the dag exists already, this flag will be ignored. If this optional parameter |
| is not specified, the global config setting will be used. |
| :param jinja_environment_kwargs: additional configuration options to be passed to Jinja |
| ``Environment`` for template rendering |
| |
| **Example**: to avoid Jinja from removing a trailing newline from template strings :: |
| |
| DAG(dag_id='my-dag', |
| jinja_environment_kwargs={ |
| 'keep_trailing_newline': True, |
| # some other jinja2 Environment options here |
| } |
| ) |
| |
| **See**: `Jinja Environment documentation |
| <https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.Environment>`_ |
| |
| :param render_template_as_native_obj: If True, uses a Jinja ``NativeEnvironment`` |
| to render templates as native Python types. If False, a Jinja |
| ``Environment`` is used to render templates as string values. |
| :param tags: List of tags to help filtering DAGs in the UI. |
| :param owner_links: Dict of owners and their links, that will be clickable on the DAGs view UI. |
| Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link. |
| e.g: {"dag_owner": "https://airflow.apache.org/"} |
| :param auto_register: Automatically register this DAG when it is used in a ``with`` block |
| :param fail_stop: Fails currently running tasks when task in DAG fails. |
| **Warning**: A fail stop dag can only have tasks with the default trigger rule ("all_success"). |
| An exception will be thrown if any task in a fail stop dag has a non default trigger rule. |
| """ |
| |
| _comps = { |
| "dag_id", |
| "task_ids", |
| "parent_dag", |
| "start_date", |
| "end_date", |
| "schedule_interval", |
| "fileloc", |
| "template_searchpath", |
| "last_loaded", |
| } |
| |
| __serialized_fields: frozenset[str] | None = None |
| |
| fileloc: str |
| """ |
| File path that needs to be imported to load this DAG or subdag. |
| |
| This may not be an actual file on disk in the case when this DAG is loaded |
| from a ZIP file or other DAG distribution format. |
| """ |
| |
| parent_dag: DAG | None = None # Gets set when DAGs are loaded |
| |
| # NOTE: When updating arguments here, please also keep arguments in @dag() |
| # below in sync. (Search for 'def dag(' in this file.) |
def __init__(
    self,
    dag_id: str,
    description: str | None = None,
    schedule: ScheduleArg = NOTSET,
    schedule_interval: ScheduleIntervalArg = NOTSET,
    timetable: Timetable | None = None,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    full_filepath: str | None = None,
    template_searchpath: str | Iterable[str] | None = None,
    template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined,
    user_defined_macros: dict | None = None,
    user_defined_filters: dict | None = None,
    default_args: dict | None = None,
    concurrency: int | None = None,
    max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"),
    max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"),
    dagrun_timeout: timedelta | None = None,
    sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None,
    default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(),
    orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"),
    catchup: bool = conf.getboolean("scheduler", "catchup_by_default"),
    on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
    on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
    doc_md: str | None = None,
    params: collections.abc.MutableMapping | None = None,
    access_control: dict | None = None,
    is_paused_upon_creation: bool | None = None,
    jinja_environment_kwargs: dict | None = None,
    render_template_as_native_obj: bool = False,
    tags: list[str] | None = None,
    owner_links: dict[str, str] | None = None,
    auto_register: bool = True,
    fail_stop: bool = False,
):
    """Initialize the DAG; see the class docstring for the meaning of each argument.

    :raises AirflowException: if a tag is longer than ``TAG_MAX_LEN``, if
        ``default_view`` or ``orientation`` is not an accepted value, if
        ``max_active_runs`` exceeds the timetable's ``active_runs_limit``, or if
        ``owner_links`` contains badly formatted links.
    :raises TypeError: if ``default_args`` is given but is not a dict.
    :raises ValueError: if more than one of ``schedule_interval``, ``timetable``
        and ``schedule`` is supplied, or if a collection ``schedule`` contains
        non-Dataset items.
    """
    from airflow.utils.task_group import TaskGroup

    if tags and any(len(tag) > TAG_MAX_LEN for tag in tags):
        raise AirflowException(f"tag cannot be longer than {TAG_MAX_LEN} characters")

    self.owner_links = owner_links if owner_links else {}
    self.user_defined_macros = user_defined_macros
    self.user_defined_filters = user_defined_filters
    if default_args and not isinstance(default_args, dict):
        raise TypeError("default_args must be a dict")
    self.default_args = copy.deepcopy(default_args or {})
    params = params or {}

    # merging potentially conflicting default_args['params'] into params
    if "params" in self.default_args:
        # Merge into a shallow copy: updating ``params`` in place would leak
        # default_args["params"] into the mapping the caller passed in, which
        # may be shared with other DAGs.
        params = copy.copy(params)
        params.update(self.default_args["params"])
        del self.default_args["params"]

    # check self.params and convert them into ParamsDict
    self.params = ParamsDict(params)

    if full_filepath:
        warnings.warn(
            "Passing full_filepath to DAG() is deprecated and has no effect",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )

    validate_key(dag_id)

    self._dag_id = dag_id
    if concurrency:
        # TODO: Remove in Airflow 3.0
        warnings.warn(
            "The 'concurrency' parameter is deprecated. Please use 'max_active_tasks'.",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        max_active_tasks = concurrency
    self._max_active_tasks = max_active_tasks
    self._pickle_id: int | None = None

    self._description = description
    # set file location to caller source path
    back = sys._getframe().f_back
    self.fileloc = back.f_code.co_filename if back else ""
    self.task_dict: dict[str, Operator] = {}

    # set timezone from start_date
    tz = None
    if start_date and start_date.tzinfo:
        # start_date is already timezone-aware here, so no fallback zone is needed.
        tz = pendulum.instance(start_date, tz=None).timezone
    elif "start_date" in self.default_args and self.default_args["start_date"]:
        date = self.default_args["start_date"]
        if not isinstance(date, datetime):
            date = timezone.parse(date)
            self.default_args["start_date"] = date
            start_date = date

        tzinfo = None if date.tzinfo else settings.TIMEZONE
        tz = pendulum.instance(date, tz=tzinfo).timezone
    self.timezone = tz or settings.TIMEZONE

    # Apply the timezone we settled on to end_date if it wasn't supplied
    if "end_date" in self.default_args and self.default_args["end_date"]:
        if isinstance(self.default_args["end_date"], str):
            self.default_args["end_date"] = timezone.parse(
                self.default_args["end_date"], timezone=self.timezone
            )

    self.start_date = timezone.convert_to_utc(start_date)
    self.end_date = timezone.convert_to_utc(end_date)

    # also convert tasks
    if "start_date" in self.default_args:
        self.default_args["start_date"] = timezone.convert_to_utc(self.default_args["start_date"])
    if "end_date" in self.default_args:
        self.default_args["end_date"] = timezone.convert_to_utc(self.default_args["end_date"])

    # sort out DAG's scheduling behavior
    scheduling_args = [schedule_interval, timetable, schedule]
    if not at_most_one(*scheduling_args):
        raise ValueError("At most one allowed for args 'schedule_interval', 'timetable', and 'schedule'.")
    if schedule_interval is not NOTSET:
        warnings.warn(
            "Param `schedule_interval` is deprecated and will be removed in a future release. "
            "Please use `schedule` instead. ",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
    if timetable is not None:
        warnings.warn(
            "Param `timetable` is deprecated and will be removed in a future release. "
            "Please use `schedule` instead. ",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )

    self.timetable: Timetable
    self.schedule_interval: ScheduleInterval
    self.dataset_triggers: Collection[Dataset] = []

    if isinstance(schedule, Collection) and not isinstance(schedule, str):
        from airflow.datasets import Dataset

        if not all(isinstance(x, Dataset) for x in schedule):
            raise ValueError("All elements in 'schedule' should be datasets")
        self.dataset_triggers = list(schedule)
    elif isinstance(schedule, Timetable):
        timetable = schedule
    elif schedule is not NOTSET:
        schedule_interval = schedule

    if self.dataset_triggers:
        self.timetable = DatasetTriggeredTimetable()
        self.schedule_interval = self.timetable.summary
    elif timetable:
        self.timetable = timetable
        self.schedule_interval = self.timetable.summary
    else:
        if isinstance(schedule_interval, ArgNotSet):
            schedule_interval = DEFAULT_SCHEDULE_INTERVAL
        self.schedule_interval = schedule_interval
        self.timetable = create_timetable(schedule_interval, self.timezone)

    if isinstance(template_searchpath, str):
        template_searchpath = [template_searchpath]
    self.template_searchpath = template_searchpath
    self.template_undefined = template_undefined
    self.last_loaded = timezone.utcnow()
    self.safe_dag_id = dag_id.replace(".", "__dot__")
    self.max_active_runs = max_active_runs
    if self.timetable.active_runs_limit is not None:
        if self.timetable.active_runs_limit < self.max_active_runs:
            raise AirflowException(
                f"Invalid max_active_runs: {type(self.timetable)} "
                f"requires max_active_runs <= {self.timetable.active_runs_limit}"
            )
    self.dagrun_timeout = dagrun_timeout
    self.sla_miss_callback = sla_miss_callback
    if default_view in DEFAULT_VIEW_PRESETS:
        self._default_view: str = default_view
    elif default_view == "tree":
        warnings.warn(
            "`default_view` of 'tree' has been renamed to 'grid' -- please update your DAG",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        self._default_view = "grid"
    else:
        raise AirflowException(
            f"Invalid values of dag.default_view: only support "
            f"{DEFAULT_VIEW_PRESETS}, but get {default_view}"
        )
    if orientation in ORIENTATION_PRESETS:
        self.orientation = orientation
    else:
        raise AirflowException(
            f"Invalid values of dag.orientation: only support "
            f"{ORIENTATION_PRESETS}, but get {orientation}"
        )
    self.catchup = catchup

    self.partial = False
    self.on_success_callback = on_success_callback
    self.on_failure_callback = on_failure_callback

    # Keeps track of any extra edge metadata (sparse; will not contain all
    # edges, so do not iterate over it for that). Outer key is upstream
    # task ID, inner key is downstream task ID.
    self.edge_info: dict[str, dict[str, EdgeInfoType]] = {}

    # To keep it in parity with Serialized DAGs
    # and identify if DAG has on_*_callback without actually storing them in Serialized JSON
    self.has_on_success_callback = self.on_success_callback is not None
    self.has_on_failure_callback = self.on_failure_callback is not None

    self._access_control = DAG._upgrade_outdated_dag_access_control(access_control)
    self.is_paused_upon_creation = is_paused_upon_creation
    self.auto_register = auto_register

    self.fail_stop = fail_stop

    self.jinja_environment_kwargs = jinja_environment_kwargs
    self.render_template_as_native_obj = render_template_as_native_obj

    self.doc_md = self.get_doc_md(doc_md)

    self.tags = tags or []
    self._task_group = TaskGroup.create_root(self)
    self.validate_schedule_and_params()
    wrong_links = dict(self.iter_invalid_owner_links())
    if wrong_links:
        raise AirflowException(
            "Wrong link format was used for the owner. Use a valid link \n"
            f"Bad formatted links are: {wrong_links}"
        )

    # this will only be set at serialization time
    # its only use is for determining the relative
    # fileloc based only on the serialize dag
    self._processor_dags_folder = None
| |
def get_doc_md(self, doc_md: str | None) -> str | None:
    """Render the DAG's ``doc_md`` through Jinja.

    A value ending in ``.md`` is treated as a template file name and loaded
    through the DAG's sandboxed template environment. Any other value is
    rendered directly as an inline ``jinja2.Template``.

    :param doc_md: Inline markdown, a ``.md`` template path, or None.
    :return: None if *doc_md* is None, a short error note if the ``.md`` file
        cannot be found, otherwise the rendered text.
    """
    if doc_md is None:
        return None

    # The environment is built up front regardless of which branch is taken.
    sandboxed_env = self.get_template_env(force_sandboxed=True)

    if doc_md.endswith(".md"):
        try:
            template = sandboxed_env.get_template(doc_md)
        except jinja2.exceptions.TemplateNotFound:
            return f"""
# Templating Error!
Not able to find the template file: `{doc_md}`.
"""
    else:
        template = jinja2.Template(doc_md)

    return template.render()
| |
def _check_schedule_interval_matches_timetable(self) -> bool:
    """Return whether ``schedule_interval`` and ``timetable`` still agree.

    Part of the validation run before the DAG is bagged. It catches a
    ``timetable`` or ``schedule_interval`` that was reassigned after the DAG
    was created, e.g.

    .. code-block:: python

        dag1 = DAG("d1", timetable=MyTimetable())
        dag1.schedule_interval = "@once"

        dag2 = DAG("d2", schedule="@once")
        dag2.timetable = MyTimetable()

    The check first compares ``schedule_interval`` with the timetable's
    ``summary``. If they differ, it rebuilds a timetable from
    ``schedule_interval`` and compares the two summaries. A custom timetable
    without a meaningful ``summary`` can defeat this heuristic, but it is the
    best available check.
    """
    if self.schedule_interval == self.timetable.summary:
        return True
    try:
        rebuilt = create_timetable(self.schedule_interval, self.timezone)
    except ValueError:
        # schedule_interval cannot be expressed as a timetable at all.
        return False
    else:
        return rebuilt.summary == self.timetable.summary
| |
def validate(self):
    """Check that the DAG's configuration is coherent.

    The DAG bag calls this before bagging the DAG. It verifies schedule
    consistency, then delegates to ``self.params.validate()`` and
    ``self.timetable.validate()``.

    :raises AirflowDagInconsistent: if ``schedule_interval`` and ``timetable``
        no longer match.
    """
    schedule_is_consistent = self._check_schedule_interval_matches_timetable()
    if not schedule_is_consistent:
        raise AirflowDagInconsistent(
            f"inconsistent schedule: timetable {self.timetable.summary!r} "
            f"does not match schedule_interval {self.schedule_interval!r}",
        )
    self.params.validate()
    self.timetable.validate()
| |
def __repr__(self):
    """Return a short debugging representation showing the DAG id."""
    dag_id = self.dag_id
    return f"<DAG: {dag_id}>"
| |
def __eq__(self, other):
    """Compare two DAGs of exactly the same class attribute-by-attribute over ``_comps``."""
    if type(self) is not type(other):
        return False
    # getattr() rather than __dict__, so that properties resolve to their values.
    return all(getattr(self, attr, None) == getattr(other, attr, None) for attr in self._comps)
| |
def __ne__(self, other):
    """Return the negation of ``self == other``."""
    is_equal = self == other
    return not is_equal
| |
def __lt__(self, other):
    """Order DAGs by ``dag_id``; ``functools.total_ordering`` derives the other comparisons."""
    own_id, other_id = self.dag_id, other.dag_id
    return own_id < other_id
| |
def __hash__(self):
    """Hash the class together with every attribute named in ``_comps``.

    Values that cannot be hashed contribute their ``repr`` instead.
    """
    components = [type(self)]
    for attr in self._comps:
        # task_ids is a list (unhashable), so use the task_dict keys as a tuple.
        value = tuple(self.task_dict.keys()) if attr == "task_ids" else getattr(self, attr, None)
        try:
            hash(value)
        except TypeError:
            value = repr(value)
        components.append(value)
    return hash(tuple(components))
| |
| # Context Manager ----------------------------------------------- |
def __enter__(self):
    """Make this DAG the current context-managed DAG and return it."""
    DagContext.push_context_managed_dag(self)
    return self
| |
def __exit__(self, _type, _value, _tb):
    """Pop the context-managed DAG; returns None, so exceptions propagate."""
    DagContext.pop_context_managed_dag()
| |
| # /Context Manager ---------------------------------------------- |
| |
@staticmethod
def _upgrade_outdated_dag_access_control(access_control=None):
    """
    Replace outdated DAG-level actions in an ``access_control`` mapping.

    The deprecated actions ``can_dag_read`` and ``can_dag_edit`` are renamed to
    ``can_read`` and ``can_edit``. For example,
    ``{'role1': {'can_dag_read'}, 'role2': {'can_dag_read', 'can_dag_edit'}}`` becomes
    ``{'role1': {'can_read'}, 'role2': {'can_read', 'can_edit'}}``.

    A ``RemovedInAirflow3Warning`` is emitted only when at least one deprecated
    action was actually present.

    :param access_control: Mapping of role name to an iterable of action names.
    :return: None if *access_control* is empty or None; otherwise a new dict
        mapping each role to a ``set`` of (upgraded) action names.
    """
    if not access_control:
        return None
    new_perm_mapping = {
        permissions.DEPRECATED_ACTION_CAN_DAG_READ: permissions.ACTION_CAN_READ,
        permissions.DEPRECATED_ACTION_CAN_DAG_EDIT: permissions.ACTION_CAN_EDIT,
    }
    updated_access_control = {}
    found_deprecated = False
    for role, perms in access_control.items():
        upgraded_perms = set()
        # Single pass over perms, so one-shot iterables are consumed only once.
        for perm in perms:
            if perm in new_perm_mapping:
                found_deprecated = True
            upgraded_perms.add(new_perm_mapping.get(perm, perm))
        updated_access_control[role] = upgraded_perms

    # Detect deprecated actions directly rather than comparing the input to the
    # upgraded dict: a list/tuple of actions never equals a set, so comparing the
    # dicts warned spuriously on inputs containing no deprecated action at all.
    if found_deprecated:
        warnings.warn(
            "The 'can_dag_read' and 'can_dag_edit' permissions are deprecated. "
            "Please use 'can_read' and 'can_edit', respectively.",
            RemovedInAirflow3Warning,
            stacklevel=3,
        )

    return updated_access_control
| |
def date_range(
    self,
    start_date: pendulum.DateTime,
    num: int | None = None,
    end_date: datetime | None = None,
) -> list[datetime]:
    """Deprecated: list logical dates of this DAG's schedule.

    With *num* given, delegates to ``airflow.utils.dates.date_range`` using
    ``self.normalized_schedule_interval`` as the step. Otherwise returns the
    logical dates from ``iter_dagrun_infos_between(..., align=False)`` between
    *start_date* and *end_date*; a missing *end_date* defaults to the current
    UTC time.
    """
    message = "`DAG.date_range()` is deprecated."
    if num is None:
        message += " Please use `DAG.iter_dagrun_infos_between(..., align=False)` instead."
        warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2)
        upper_bound = timezone.utcnow() if end_date is None else end_date
        infos = self.iter_dagrun_infos_between(start_date, pendulum.instance(upper_bound), align=False)
        return [info.logical_date for info in infos]

    warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2)
    with warnings.catch_warnings():
        # Silence RemovedInAirflow3Warning raised while building the legacy range.
        warnings.simplefilter("ignore", RemovedInAirflow3Warning)
        return utils_date_range(
            start_date=start_date, num=num, delta=self.normalized_schedule_interval
        )
| |
def is_fixed_time_schedule(self):
    """Deprecated: return the negation of the timetable's ``_should_fix_dst``.

    Returns True when that attribute is unavailable (AttributeError).
    """
    warnings.warn(
        "`DAG.is_fixed_time_schedule()` is deprecated.",
        category=RemovedInAirflow3Warning,
        stacklevel=2,
    )
    try:
        fixes_dst = self.timetable._should_fix_dst
        return not fixes_dst
    except AttributeError:
        return True
| |
def following_schedule(self, dttm):
    """
    Deprecated: return the start of the schedule that follows *dttm*, in UTC.

    :param dttm: utc datetime
    :return: utc datetime, or None when there is no next run
    """
    warnings.warn(
        "`DAG.following_schedule()` is deprecated. Use `DAG.next_dagrun_info(restricted=False)` instead.",
        category=RemovedInAirflow3Warning,
        stacklevel=2,
    )
    inferred_interval = self.infer_automated_data_interval(timezone.coerce_datetime(dttm))
    upcoming = self.next_dagrun_info(inferred_interval, restricted=False)
    return None if upcoming is None else upcoming.data_interval.start
| |
| def previous_schedule(self, dttm): |
| from airflow.timetables.interval import _DataIntervalTimetable |
| |
| warnings.warn( |
| "`DAG.previous_schedule()` is deprecated.", |
| category=RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| if not isinstance(self.timetable, _DataIntervalTimetable): |
| return None |
| return self.timetable._get_prev(timezone.coerce_datetime(dttm)) |
| |
| def get_next_data_interval(self, dag_model: DagModel) -> DataInterval | None: |
| """Get the data interval of the next scheduled run. |
| |
| For compatibility, this method infers the data interval from the DAG's |
| schedule if the run does not have an explicit one set, which is possible |
| for runs created prior to AIP-39. |
| |
| This function is private to Airflow core and should not be depended on as a |
| part of the Python API. |
| |
| :meta private: |
| """ |
| if self.dag_id != dag_model.dag_id: |
| raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {dag_model.dag_id}") |
| if dag_model.next_dagrun is None: # Next run not scheduled. |
| return None |
| data_interval = dag_model.next_dagrun_data_interval |
| if data_interval is not None: |
| return data_interval |
| |
| # Compatibility: A run was scheduled without an explicit data interval. |
| # This means the run was scheduled before AIP-39 implementation. Try to |
| # infer from the logical date. |
| return self.infer_automated_data_interval(dag_model.next_dagrun) |
| |
| def get_run_data_interval(self, run: DagRun) -> DataInterval: |
| """Get the data interval of this run. |
| |
| For compatibility, this method infers the data interval from the DAG's |
| schedule if the run does not have an explicit one set, which is possible for |
| runs created prior to AIP-39. |
| |
| This function is private to Airflow core and should not be depended on as a |
| part of the Python API. |
| |
| :meta private: |
| """ |
| if run.dag_id is not None and run.dag_id != self.dag_id: |
| raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {run.dag_id}") |
| data_interval = _get_model_data_interval(run, "data_interval_start", "data_interval_end") |
| if data_interval is not None: |
| return data_interval |
| # Compatibility: runs created before AIP-39 implementation don't have an |
| # explicit data interval. Try to infer from the logical date. |
| return self.infer_automated_data_interval(run.execution_date) |
| |
| def infer_automated_data_interval(self, logical_date: datetime) -> DataInterval: |
| """Infer a data interval for a run against this DAG. |
| |
| This method is used to bridge runs created prior to AIP-39 |
| implementation, which do not have an explicit data interval. Therefore, |
| this method only considers ``schedule_interval`` values valid prior to |
| Airflow 2.2. |
| |
| DO NOT use this method is there is a known data interval. |
| """ |
| timetable_type = type(self.timetable) |
| if issubclass(timetable_type, (NullTimetable, OnceTimetable)): |
| return DataInterval.exact(timezone.coerce_datetime(logical_date)) |
| start = timezone.coerce_datetime(logical_date) |
| if issubclass(timetable_type, CronDataIntervalTimetable): |
| end = cast(CronDataIntervalTimetable, self.timetable)._get_next(start) |
| elif issubclass(timetable_type, DeltaDataIntervalTimetable): |
| end = cast(DeltaDataIntervalTimetable, self.timetable)._get_next(start) |
| else: |
| raise ValueError(f"Not a valid timetable: {self.timetable!r}") |
| return DataInterval(start, end) |
| |
| def next_dagrun_info( |
| self, |
| last_automated_dagrun: None | datetime | DataInterval, |
| *, |
| restricted: bool = True, |
| ) -> DagRunInfo | None: |
| """Get information about the next DagRun of this dag after ``date_last_automated_dagrun``. |
| |
| This calculates what time interval the next DagRun should operate on |
| (its execution date) and when it can be scheduled, according to the |
| dag's timetable, start_date, end_date, etc. This doesn't check max |
| active run or any other "max_active_tasks" type limits, but only |
| performs calculations based on the various date and interval fields of |
| this dag and its tasks. |
| |
| :param last_automated_dagrun: The ``max(execution_date)`` of |
| existing "automated" DagRuns for this dag (scheduled or backfill, |
| but not manual). |
| :param restricted: If set to *False* (default is *True*), ignore |
| ``start_date``, ``end_date``, and ``catchup`` specified on the DAG |
| or tasks. |
| :return: DagRunInfo of the next dagrun, or None if a dagrun is not |
| going to be scheduled. |
| """ |
| # Never schedule a subdag. It will be scheduled by its parent dag. |
| if self.is_subdag: |
| return None |
| |
| data_interval = None |
| if isinstance(last_automated_dagrun, datetime): |
| warnings.warn( |
| "Passing a datetime to DAG.next_dagrun_info is deprecated. Use a DataInterval instead.", |
| RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| data_interval = self.infer_automated_data_interval( |
| timezone.coerce_datetime(last_automated_dagrun) |
| ) |
| else: |
| data_interval = last_automated_dagrun |
| if restricted: |
| restriction = self._time_restriction |
| else: |
| restriction = TimeRestriction(earliest=None, latest=None, catchup=True) |
| try: |
| info = self.timetable.next_dagrun_info( |
| last_automated_data_interval=data_interval, |
| restriction=restriction, |
| ) |
| except Exception: |
| self.log.exception( |
| "Failed to fetch run info after data interval %s for DAG %r", |
| data_interval, |
| self.dag_id, |
| ) |
| info = None |
| return info |
| |
| def next_dagrun_after_date(self, date_last_automated_dagrun: pendulum.DateTime | None): |
| warnings.warn( |
| "`DAG.next_dagrun_after_date()` is deprecated. Please use `DAG.next_dagrun_info()` instead.", |
| category=RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| if date_last_automated_dagrun is None: |
| data_interval = None |
| else: |
| data_interval = self.infer_automated_data_interval(date_last_automated_dagrun) |
| info = self.next_dagrun_info(data_interval) |
| if info is None: |
| return None |
| return info.run_after |
| |
| @cached_property |
| def _time_restriction(self) -> TimeRestriction: |
| start_dates = [t.start_date for t in self.tasks if t.start_date] |
| if self.start_date is not None: |
| start_dates.append(self.start_date) |
| earliest = None |
| if start_dates: |
| earliest = timezone.coerce_datetime(min(start_dates)) |
| latest = self.end_date |
| end_dates = [t.end_date for t in self.tasks if t.end_date] |
| if len(end_dates) == len(self.tasks): # not exists null end_date |
| if self.end_date is not None: |
| end_dates.append(self.end_date) |
| if end_dates: |
| latest = timezone.coerce_datetime(max(end_dates)) |
| return TimeRestriction(earliest, latest, self.catchup) |
| |
| def iter_dagrun_infos_between( |
| self, |
| earliest: pendulum.DateTime | None, |
| latest: pendulum.DateTime, |
| *, |
| align: bool = True, |
| ) -> Iterable[DagRunInfo]: |
| """Yield DagRunInfo using this DAG's timetable between given interval. |
| |
| DagRunInfo instances yielded if their ``logical_date`` is not earlier |
| than ``earliest``, nor later than ``latest``. The instances are ordered |
| by their ``logical_date`` from earliest to latest. |
| |
| If ``align`` is ``False``, the first run will happen immediately on |
| ``earliest``, even if it does not fall on the logical timetable schedule. |
| The default is ``True``, but subdags will ignore this value and always |
| behave as if this is set to ``False`` for backward compatibility. |
| |
| Example: A DAG is scheduled to run every midnight (``0 0 * * *``). If |
| ``earliest`` is ``2021-06-03 23:00:00``, the first DagRunInfo would be |
| ``2021-06-03 23:00:00`` if ``align=False``, and ``2021-06-04 00:00:00`` |
| if ``align=True``. |
| """ |
| if earliest is None: |
| earliest = self._time_restriction.earliest |
| if earliest is None: |
| raise ValueError("earliest was None and we had no value in time_restriction to fallback on") |
| earliest = timezone.coerce_datetime(earliest) |
| latest = timezone.coerce_datetime(latest) |
| |
| restriction = TimeRestriction(earliest, latest, catchup=True) |
| |
| # HACK: Sub-DAGs are currently scheduled differently. For example, say |
| # the schedule is @daily and start is 2021-06-03 22:16:00, a top-level |
| # DAG should be first scheduled to run on midnight 2021-06-04, but a |
| # sub-DAG should be first scheduled to run RIGHT NOW. We can change |
| # this, but since sub-DAGs are going away in 3.0 anyway, let's keep |
| # compatibility for now and remove this entirely later. |
| if self.is_subdag: |
| align = False |
| |
| try: |
| info = self.timetable.next_dagrun_info( |
| last_automated_data_interval=None, |
| restriction=restriction, |
| ) |
| except Exception: |
| self.log.exception( |
| "Failed to fetch run info after data interval %s for DAG %r", |
| None, |
| self.dag_id, |
| ) |
| info = None |
| |
| if info is None: |
| # No runs to be scheduled between the user-supplied timeframe. But |
| # if align=False, "invent" a data interval for the timeframe itself. |
| if not align: |
| yield DagRunInfo.interval(earliest, latest) |
| return |
| |
| # If align=False and earliest does not fall on the timetable's logical |
| # schedule, "invent" a data interval for it. |
| if not align and info.logical_date != earliest: |
| yield DagRunInfo.interval(earliest, info.data_interval.start) |
| |
| # Generate naturally according to schedule. |
| while info is not None: |
| yield info |
| try: |
| info = self.timetable.next_dagrun_info( |
| last_automated_data_interval=info.data_interval, |
| restriction=restriction, |
| ) |
| except Exception: |
| self.log.exception( |
| "Failed to fetch run info after data interval %s for DAG %r", |
| info.data_interval if info else "<NONE>", |
| self.dag_id, |
| ) |
| break |
| |
| def get_run_dates(self, start_date, end_date=None) -> list: |
| """ |
| Returns a list of dates between the interval received as parameter using this |
| dag's schedule interval. Returned dates can be used for execution dates. |
| |
| :param start_date: The start date of the interval. |
| :param end_date: The end date of the interval. Defaults to ``timezone.utcnow()``. |
| :return: A list of dates within the interval following the dag's schedule. |
| """ |
| warnings.warn( |
| "`DAG.get_run_dates()` is deprecated. Please use `DAG.iter_dagrun_infos_between()` instead.", |
| category=RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| earliest = timezone.coerce_datetime(start_date) |
| if end_date is None: |
| latest = pendulum.now(timezone.utc) |
| else: |
| latest = timezone.coerce_datetime(end_date) |
| return [info.logical_date for info in self.iter_dagrun_infos_between(earliest, latest)] |
| |
| def normalize_schedule(self, dttm): |
| warnings.warn( |
| "`DAG.normalize_schedule()` is deprecated.", |
| category=RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| with warnings.catch_warnings(): |
| warnings.simplefilter("ignore", RemovedInAirflow3Warning) |
| following = self.following_schedule(dttm) |
| if not following: # in case of @once |
| return dttm |
| with warnings.catch_warnings(): |
| warnings.simplefilter("ignore", RemovedInAirflow3Warning) |
| previous_of_following = self.previous_schedule(following) |
| if previous_of_following != dttm: |
| return following |
| return dttm |
| |
| @provide_session |
| def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False): |
| return get_last_dagrun( |
| self.dag_id, session=session, include_externally_triggered=include_externally_triggered |
| ) |
| |
| @provide_session |
| def has_dag_runs(self, session=NEW_SESSION, include_externally_triggered=True) -> bool: |
| return ( |
| get_last_dagrun( |
| self.dag_id, session=session, include_externally_triggered=include_externally_triggered |
| ) |
| is not None |
| ) |
| |
| @property |
| def dag_id(self) -> str: |
| return self._dag_id |
| |
| @dag_id.setter |
| def dag_id(self, value: str) -> None: |
| self._dag_id = value |
| |
| @property |
| def is_subdag(self) -> bool: |
| return self.parent_dag is not None |
| |
| @property |
| def full_filepath(self) -> str: |
| """Full file path to the DAG. |
| |
| :meta private: |
| """ |
| warnings.warn( |
| "DAG.full_filepath is deprecated in favour of fileloc", |
| RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| return self.fileloc |
| |
| @full_filepath.setter |
| def full_filepath(self, value) -> None: |
| warnings.warn( |
| "DAG.full_filepath is deprecated in favour of fileloc", |
| RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| self.fileloc = value |
| |
| @property |
| def concurrency(self) -> int: |
| # TODO: Remove in Airflow 3.0 |
| warnings.warn( |
| "The 'DAG.concurrency' attribute is deprecated. Please use 'DAG.max_active_tasks'.", |
| RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| return self._max_active_tasks |
| |
| @concurrency.setter |
| def concurrency(self, value: int): |
| self._max_active_tasks = value |
| |
| @property |
| def max_active_tasks(self) -> int: |
| return self._max_active_tasks |
| |
| @max_active_tasks.setter |
| def max_active_tasks(self, value: int): |
| self._max_active_tasks = value |
| |
| @property |
| def access_control(self): |
| return self._access_control |
| |
| @access_control.setter |
| def access_control(self, value): |
| self._access_control = DAG._upgrade_outdated_dag_access_control(value) |
| |
| @property |
| def description(self) -> str | None: |
| return self._description |
| |
| @property |
| def default_view(self) -> str: |
| return self._default_view |
| |
| @property |
| def pickle_id(self) -> int | None: |
| return self._pickle_id |
| |
| @pickle_id.setter |
| def pickle_id(self, value: int) -> None: |
| self._pickle_id = value |
| |
| def param(self, name: str, default: Any = NOTSET) -> DagParam: |
| """ |
| Return a DagParam object for current dag. |
| |
| :param name: dag parameter name. |
| :param default: fallback value for dag parameter. |
| :return: DagParam instance for specified name and current dag. |
| """ |
| return DagParam(current_dag=self, name=name, default=default) |
| |
| @property |
| def tasks(self) -> list[Operator]: |
| return list(self.task_dict.values()) |
| |
| @tasks.setter |
| def tasks(self, val): |
| raise AttributeError("DAG.tasks can not be modified. Use dag.add_task() instead.") |
| |
| @property |
| def task_ids(self) -> list[str]: |
| return list(self.task_dict.keys()) |
| |
| @property |
| def teardowns(self) -> list[Operator]: |
| return [task for task in self.tasks if getattr(task, "is_teardown", None)] |
| |
| @property |
| def tasks_upstream_of_teardowns(self) -> list[Operator]: |
| upstream_tasks = [t.upstream_list for t in self.teardowns] |
| return [val for sublist in upstream_tasks for val in sublist if not getattr(val, "is_teardown", None)] |
| |
| @property |
| def task_group(self) -> TaskGroup: |
| return self._task_group |
| |
| @property |
| def filepath(self) -> str: |
| """Relative file path to the DAG. |
| |
| :meta private: |
| """ |
| warnings.warn( |
| "filepath is deprecated, use relative_fileloc instead", |
| RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| return str(self.relative_fileloc) |
| |
| @property |
| def relative_fileloc(self) -> pathlib.Path: |
| """File location of the importable dag 'file' relative to the configured DAGs folder.""" |
| path = pathlib.Path(self.fileloc) |
| try: |
| rel_path = path.relative_to(self._processor_dags_folder or settings.DAGS_FOLDER) |
| if rel_path == pathlib.Path("."): |
| return path |
| else: |
| return rel_path |
| except ValueError: |
| # Not relative to DAGS_FOLDER. |
| return path |
| |
| @property |
| def folder(self) -> str: |
| """Folder location of where the DAG object is instantiated.""" |
| return os.path.dirname(self.fileloc) |
| |
| @property |
| def owner(self) -> str: |
| """ |
| Return list of all owners found in DAG tasks. |
| |
| :return: Comma separated list of owners in DAG tasks |
| """ |
| return ", ".join({t.owner for t in self.tasks}) |
| |
| @property |
| def allow_future_exec_dates(self) -> bool: |
| return settings.ALLOW_FUTURE_EXEC_DATES and not self.timetable.can_run |
| |
| @provide_session |
| def get_concurrency_reached(self, session=NEW_SESSION) -> bool: |
| """ |
| Returns a boolean indicating whether the max_active_tasks limit for this DAG |
| has been reached. |
| """ |
| TI = TaskInstance |
| qry = session.query(func.count(TI.task_id)).filter( |
| TI.dag_id == self.dag_id, |
| TI.state == State.RUNNING, |
| ) |
| return qry.scalar() >= self.max_active_tasks |
| |
| @property |
| def concurrency_reached(self): |
| """This attribute is deprecated. Please use `airflow.models.DAG.get_concurrency_reached` method.""" |
| warnings.warn( |
| "This attribute is deprecated. Please use `airflow.models.DAG.get_concurrency_reached` method.", |
| RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| return self.get_concurrency_reached() |
| |
| @provide_session |
| def get_is_active(self, session=NEW_SESSION) -> None: |
| """Returns a boolean indicating whether this DAG is active.""" |
| return session.query(DagModel.is_active).filter(DagModel.dag_id == self.dag_id).scalar() |
| |
| @provide_session |
| def get_is_paused(self, session=NEW_SESSION) -> None: |
| """Returns a boolean indicating whether this DAG is paused.""" |
| return session.query(DagModel.is_paused).filter(DagModel.dag_id == self.dag_id).scalar() |
| |
| @property |
| def is_paused(self): |
| """This attribute is deprecated. Please use `airflow.models.DAG.get_is_paused` method.""" |
| warnings.warn( |
| "This attribute is deprecated. Please use `airflow.models.DAG.get_is_paused` method.", |
| RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| return self.get_is_paused() |
| |
| @property |
| def normalized_schedule_interval(self) -> ScheduleInterval: |
| warnings.warn( |
| "DAG.normalized_schedule_interval() is deprecated.", |
| category=RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| if isinstance(self.schedule_interval, str) and self.schedule_interval in cron_presets: |
| _schedule_interval: ScheduleInterval = cron_presets.get(self.schedule_interval) |
| elif self.schedule_interval == "@once": |
| _schedule_interval = None |
| else: |
| _schedule_interval = self.schedule_interval |
| return _schedule_interval |
| |
| @provide_session |
| def handle_callback(self, dagrun, success=True, reason=None, session=NEW_SESSION): |
| """ |
| Triggers the appropriate callback depending on the value of success, namely the |
| on_failure_callback or on_success_callback. This method gets the context of a |
| single TaskInstance part of this DagRun and passes that to the callable along |
| with a 'reason', primarily to differentiate DagRun failures. |
| |
| .. note: The logs end up in |
| ``$AIRFLOW_HOME/logs/scheduler/latest/PROJECT/DAG_FILE.py.log`` |
| |
| :param dagrun: DagRun object |
| :param success: Flag to specify if failure or success callback should be called |
| :param reason: Completion reason |
| :param session: Database session |
| """ |
| callbacks = self.on_success_callback if success else self.on_failure_callback |
| if callbacks: |
| callbacks = callbacks if isinstance(callbacks, list) else [callbacks] |
| tis = dagrun.get_task_instances(session=session) |
| ti = tis[-1] # get first TaskInstance of DagRun |
| ti.task = self.get_task(ti.task_id) |
| context = ti.get_template_context(session=session) |
| context.update({"reason": reason}) |
| for callback in callbacks: |
| self.log.info("Executing dag callback function: %s", callback) |
| try: |
| callback(context) |
| except Exception: |
| self.log.exception("failed to invoke dag state update callback") |
| Stats.incr("dag.callback_exceptions", tags={"dag_id": dagrun.dag_id}) |
| |
| def get_active_runs(self): |
| """ |
| Returns a list of dag run execution dates currently running. |
| |
| :return: List of execution dates |
| """ |
| runs = DagRun.find(dag_id=self.dag_id, state=State.RUNNING) |
| |
| active_dates = [] |
| for run in runs: |
| active_dates.append(run.execution_date) |
| |
| return active_dates |
| |
| @provide_session |
| def get_num_active_runs(self, external_trigger=None, only_running=True, session=NEW_SESSION): |
| """ |
| Returns the number of active "running" dag runs. |
| |
| :param external_trigger: True for externally triggered active dag runs |
| :param session: |
| :return: number greater than 0 for active dag runs |
| """ |
| # .count() is inefficient |
| query = session.query(func.count()).filter(DagRun.dag_id == self.dag_id) |
| if only_running: |
| query = query.filter(DagRun.state == State.RUNNING) |
| else: |
| query = query.filter(DagRun.state.in_({State.RUNNING, State.QUEUED})) |
| |
| if external_trigger is not None: |
| query = query.filter( |
| DagRun.external_trigger == (expression.true() if external_trigger else expression.false()) |
| ) |
| |
| return query.scalar() |
| |
| @provide_session |
| def get_dagrun( |
| self, |
| execution_date: datetime | None = None, |
| run_id: str | None = None, |
| session: Session = NEW_SESSION, |
| ): |
| """ |
| Returns the dag run for a given execution date or run_id if it exists, otherwise |
| none. |
| |
| :param execution_date: The execution date of the DagRun to find. |
| :param run_id: The run_id of the DagRun to find. |
| :param session: |
| :return: The DagRun if found, otherwise None. |
| """ |
| if not (execution_date or run_id): |
| raise TypeError("You must provide either the execution_date or the run_id") |
| query = session.query(DagRun) |
| if execution_date: |
| query = query.filter(DagRun.dag_id == self.dag_id, DagRun.execution_date == execution_date) |
| if run_id: |
| query = query.filter(DagRun.dag_id == self.dag_id, DagRun.run_id == run_id) |
| return query.first() |
| |
| @provide_session |
| def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION): |
| """ |
| Returns the list of dag runs between start_date (inclusive) and end_date (inclusive). |
| |
| :param start_date: The starting execution date of the DagRun to find. |
| :param end_date: The ending execution date of the DagRun to find. |
| :param session: |
| :return: The list of DagRuns found. |
| """ |
| dagruns = ( |
| session.query(DagRun) |
| .filter( |
| DagRun.dag_id == self.dag_id, |
| DagRun.execution_date >= start_date, |
| DagRun.execution_date <= end_date, |
| ) |
| .all() |
| ) |
| |
| return dagruns |
| |
| @provide_session |
| def get_latest_execution_date(self, session: Session = NEW_SESSION) -> pendulum.DateTime | None: |
| """Returns the latest date for which at least one dag run exists.""" |
| return session.query(func.max(DagRun.execution_date)).filter(DagRun.dag_id == self.dag_id).scalar() |
| |
| @property |
| def latest_execution_date(self): |
| """This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`.""" |
| warnings.warn( |
| "This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`.", |
| RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| return self.get_latest_execution_date() |
| |
| @property |
| def subdags(self): |
| """Returns a list of the subdag objects associated to this DAG.""" |
| # Check SubDag for class but don't check class directly |
| from airflow.operators.subdag import SubDagOperator |
| |
| subdag_lst = [] |
| for task in self.tasks: |
| if ( |
| isinstance(task, SubDagOperator) |
| or |
| # TODO remove in Airflow 2.0 |
| type(task).__name__ == "SubDagOperator" |
| or task.task_type == "SubDagOperator" |
| ): |
| subdag_lst.append(task.subdag) |
| subdag_lst += task.subdag.subdags |
| return subdag_lst |
| |
| def resolve_template_files(self): |
| for t in self.tasks: |
| t.resolve_template_files() |
| |
| def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment: |
| """Build a Jinja2 environment.""" |
| # Collect directories to search for template files |
| searchpath = [self.folder] |
| if self.template_searchpath: |
| searchpath += self.template_searchpath |
| |
| # Default values (for backward compatibility) |
| jinja_env_options = { |
| "loader": jinja2.FileSystemLoader(searchpath), |
| "undefined": self.template_undefined, |
| "extensions": ["jinja2.ext.do"], |
| "cache_size": 0, |
| } |
| if self.jinja_environment_kwargs: |
| jinja_env_options.update(self.jinja_environment_kwargs) |
| env: jinja2.Environment |
| if self.render_template_as_native_obj and not force_sandboxed: |
| env = airflow.templates.NativeEnvironment(**jinja_env_options) |
| else: |
| env = airflow.templates.SandboxedEnvironment(**jinja_env_options) |
| |
| # Add any user defined items. Safe to edit globals as long as no templates are rendered yet. |
| # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals |
| if self.user_defined_macros: |
| env.globals.update(self.user_defined_macros) |
| if self.user_defined_filters: |
| env.filters.update(self.user_defined_filters) |
| |
| return env |
| |
| def set_dependency(self, upstream_task_id, downstream_task_id): |
| """ |
| Simple utility method to set dependency between two tasks that |
| already have been added to the DAG using add_task(). |
| """ |
| self.get_task(upstream_task_id).set_downstream(self.get_task(downstream_task_id)) |
| |
| @provide_session |
| def get_task_instances_before( |
| self, |
| base_date: datetime, |
| num: int, |
| *, |
| session: Session = NEW_SESSION, |
| ) -> list[TaskInstance]: |
| """Get ``num`` task instances before (including) ``base_date``. |
| |
| The returned list may contain exactly ``num`` task instances |
| corresponding to any DagRunType. It can have less if there are |
| less than ``num`` scheduled DAG runs before ``base_date``. |
| """ |
| execution_dates: list[Any] = ( |
| session.query(DagRun.execution_date) |
| .filter( |
| DagRun.dag_id == self.dag_id, |
| DagRun.execution_date <= base_date, |
| ) |
| .order_by(DagRun.execution_date.desc()) |
| .limit(num) |
| .all() |
| ) |
| |
| if len(execution_dates) == 0: |
| return self.get_task_instances(start_date=base_date, end_date=base_date, session=session) |
| |
| min_date: datetime | None = execution_dates[-1]._mapping.get( |
| "execution_date" |
| ) # getting the last value from the list |
| |
| return self.get_task_instances(start_date=min_date, end_date=base_date, session=session) |
| |
| @provide_session |
| def get_task_instances( |
| self, |
| start_date: datetime | None = None, |
| end_date: datetime | None = None, |
| state: list[TaskInstanceState] | None = None, |
| session: Session = NEW_SESSION, |
| ) -> list[TaskInstance]: |
| if not start_date: |
| start_date = (timezone.utcnow() - timedelta(30)).replace( |
| hour=0, minute=0, second=0, microsecond=0 |
| ) |
| |
| query = self._get_task_instances( |
| task_ids=None, |
| start_date=start_date, |
| end_date=end_date, |
| run_id=None, |
| state=state or (), |
| include_subdags=False, |
| include_parentdag=False, |
| include_dependent_dags=False, |
| exclude_task_ids=(), |
| session=session, |
| ) |
| return cast(Query, query).order_by(DagRun.execution_date).all() |
| |
    # Typing-only overload for the default call (``as_pk_tuple`` omitted): the
    # result is an iterable of full ``TaskInstance`` objects. The recursion
    # tracking keywords are not part of this form.
    @overload
    def _get_task_instances(
        self,
        *,
        task_ids: Collection[str | tuple[str, int]] | None,
        start_date: datetime | None,
        end_date: datetime | None,
        run_id: str | None,
        state: TaskInstanceState | Sequence[TaskInstanceState],
        include_subdags: bool,
        include_parentdag: bool,
        include_dependent_dags: bool,
        exclude_task_ids: Collection[str | tuple[str, int]] | None,
        session: Session,
        dag_bag: DagBag | None = ...,
    ) -> Iterable[TaskInstance]:
        ...  # pragma: no cover

    # Typing-only overload for ``as_pk_tuple=True``: the result is a set of
    # ``TaskInstanceKey`` primary keys. Only this form accepts the
    # ``recursion_depth`` / ``max_recursion_depth`` / ``visited_external_tis``
    # keywords used when recursing into parent and dependent DAGs.
    @overload
    def _get_task_instances(
        self,
        *,
        task_ids: Collection[str | tuple[str, int]] | None,
        as_pk_tuple: Literal[True],
        start_date: datetime | None,
        end_date: datetime | None,
        run_id: str | None,
        state: TaskInstanceState | Sequence[TaskInstanceState],
        include_subdags: bool,
        include_parentdag: bool,
        include_dependent_dags: bool,
        exclude_task_ids: Collection[str | tuple[str, int]] | None,
        session: Session,
        dag_bag: DagBag | None = ...,
        recursion_depth: int = ...,
        max_recursion_depth: int = ...,
        visited_external_tis: set[TaskInstanceKey] = ...,
    ) -> set[TaskInstanceKey]:
        ...  # pragma: no cover
| |
| def _get_task_instances( |
| self, |
| *, |
| task_ids: Collection[str | tuple[str, int]] | None, |
| as_pk_tuple: Literal[True, None] = None, |
| start_date: datetime | None, |
| end_date: datetime | None, |
| run_id: str | None, |
| state: TaskInstanceState | Sequence[TaskInstanceState], |
| include_subdags: bool, |
| include_parentdag: bool, |
| include_dependent_dags: bool, |
| exclude_task_ids: Collection[str | tuple[str, int]] | None, |
| session: Session, |
| dag_bag: DagBag | None = None, |
| recursion_depth: int = 0, |
| max_recursion_depth: int | None = None, |
| visited_external_tis: set[TaskInstanceKey] | None = None, |
| ) -> Iterable[TaskInstance] | set[TaskInstanceKey]: |
| TI = TaskInstance |
| |
| # If we are looking at subdags/dependent dags we want to avoid UNION calls |
| # in SQL (it doesn't play nice with fields that have no equality operator, |
| # like JSON types), we instead build our result set separately. |
| # |
| # This will be empty if we are only looking at one dag, in which case |
| # we can return the filtered TI query object directly. |
| result: set[TaskInstanceKey] = set() |
| |
| # Do we want full objects, or just the primary columns? |
| if as_pk_tuple: |
| tis = session.query(TI.dag_id, TI.task_id, TI.run_id, TI.map_index) |
| else: |
| tis = session.query(TaskInstance) |
| tis = tis.join(TaskInstance.dag_run) |
| |
| if include_subdags: |
| # Crafting the right filter for dag_id and task_ids combo |
| conditions = [] |
| for dag in self.subdags + [self]: |
| conditions.append( |
| (TaskInstance.dag_id == dag.dag_id) & TaskInstance.task_id.in_(dag.task_ids) |
| ) |
| tis = tis.filter(or_(*conditions)) |
| elif self.partial: |
| tis = tis.filter(TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(self.task_ids)) |
| else: |
| tis = tis.filter(TaskInstance.dag_id == self.dag_id) |
| if run_id: |
| tis = tis.filter(TaskInstance.run_id == run_id) |
| if start_date: |
| tis = tis.filter(DagRun.execution_date >= start_date) |
| if task_ids is not None: |
| tis = tis.filter(TaskInstance.ti_selector_condition(task_ids)) |
| |
| # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC |
| if end_date or not self.allow_future_exec_dates: |
| end_date = end_date or timezone.utcnow() |
| tis = tis.filter(DagRun.execution_date <= end_date) |
| |
| if state: |
| if isinstance(state, (str, TaskInstanceState)): |
| tis = tis.filter(TaskInstance.state == state) |
| elif len(state) == 1: |
| tis = tis.filter(TaskInstance.state == state[0]) |
| else: |
| # this is required to deal with NULL values |
| if None in state: |
| if all(x is None for x in state): |
| tis = tis.filter(TaskInstance.state.is_(None)) |
| else: |
| not_none_state = [s for s in state if s] |
| tis = tis.filter( |
| or_(TaskInstance.state.in_(not_none_state), TaskInstance.state.is_(None)) |
| ) |
| else: |
| tis = tis.filter(TaskInstance.state.in_(state)) |
| |
| # Next, get any of them from our parent DAG (if there is one) |
| if include_parentdag and self.parent_dag is not None: |
| |
| if visited_external_tis is None: |
| visited_external_tis = set() |
| |
| p_dag = self.parent_dag.partial_subset( |
| task_ids_or_regex=r"^{}$".format(self.dag_id.split(".")[1]), |
| include_upstream=False, |
| include_downstream=True, |
| ) |
| result.update( |
| p_dag._get_task_instances( |
| task_ids=task_ids, |
| start_date=start_date, |
| end_date=end_date, |
| run_id=None, |
| state=state, |
| include_subdags=include_subdags, |
| include_parentdag=False, |
| include_dependent_dags=include_dependent_dags, |
| as_pk_tuple=True, |
| exclude_task_ids=exclude_task_ids, |
| session=session, |
| dag_bag=dag_bag, |
| recursion_depth=recursion_depth, |
| max_recursion_depth=max_recursion_depth, |
| visited_external_tis=visited_external_tis, |
| ) |
| ) |
| |
| if include_dependent_dags: |
| # Recursively find external tasks indicated by ExternalTaskMarker |
| from airflow.sensors.external_task import ExternalTaskMarker |
| |
| query = tis |
| if as_pk_tuple: |
| condition = TI.filter_for_tis(TaskInstanceKey(*cols) for cols in tis.all()) |
| if condition is not None: |
| query = session.query(TI).filter(condition) |
| |
| if visited_external_tis is None: |
| visited_external_tis = set() |
| |
| for ti in query.filter(TI.operator == ExternalTaskMarker.__name__): |
| ti_key = ti.key.primary |
| if ti_key in visited_external_tis: |
| continue |
| |
| visited_external_tis.add(ti_key) |
| |
| task: ExternalTaskMarker = cast(ExternalTaskMarker, copy.copy(self.get_task(ti.task_id))) |
| ti.task = task |
| |
| if max_recursion_depth is None: |
| # Maximum recursion depth allowed is the recursion_depth of the first |
| # ExternalTaskMarker in the tasks to be visited. |
| max_recursion_depth = task.recursion_depth |
| |
| if recursion_depth + 1 > max_recursion_depth: |
| # Prevent cycles or accidents. |
| raise AirflowException( |
| f"Maximum recursion depth {max_recursion_depth} reached for " |
| f"{ExternalTaskMarker.__name__} {ti.task_id}. " |
| f"Attempted to clear too many tasks or there may be a cyclic dependency." |
| ) |
| ti.render_templates() |
| external_tis = ( |
| session.query(TI) |
| .join(TI.dag_run) |
| .filter( |
| TI.dag_id == task.external_dag_id, |
| TI.task_id == task.external_task_id, |
| DagRun.execution_date == pendulum.parse(task.execution_date), |
| ) |
| ) |
| |
| for tii in external_tis: |
| if not dag_bag: |
| from airflow.models.dagbag import DagBag |
| |
| dag_bag = DagBag(read_dags_from_db=True) |
| external_dag = dag_bag.get_dag(tii.dag_id, session=session) |
| if not external_dag: |
| raise AirflowException(f"Could not find dag {tii.dag_id}") |
| downstream = external_dag.partial_subset( |
| task_ids_or_regex=[tii.task_id], |
| include_upstream=False, |
| include_downstream=True, |
| ) |
| result.update( |
| downstream._get_task_instances( |
| task_ids=None, |
| run_id=tii.run_id, |
| start_date=None, |
| end_date=None, |
| state=state, |
| include_subdags=include_subdags, |
| include_dependent_dags=include_dependent_dags, |
| include_parentdag=False, |
| as_pk_tuple=True, |
| exclude_task_ids=exclude_task_ids, |
| dag_bag=dag_bag, |
| session=session, |
| recursion_depth=recursion_depth + 1, |
| max_recursion_depth=max_recursion_depth, |
| visited_external_tis=visited_external_tis, |
| ) |
| ) |
| |
| if result or as_pk_tuple: |
| # Only execute the `ti` query if we have also collected some other results (i.e. subdags etc.) |
| if as_pk_tuple: |
| result.update(TaskInstanceKey(**cols._mapping) for cols in tis.all()) |
| else: |
| result.update(ti.key for ti in tis) |
| |
| if exclude_task_ids is not None: |
| result = { |
| task |
| for task in result |
| if task.task_id not in exclude_task_ids |
| and (task.task_id, task.map_index) not in exclude_task_ids |
| } |
| |
| if as_pk_tuple: |
| return result |
| if result: |
| # We've been asked for objects, lets combine it all back in to a result set |
| ti_filters = TI.filter_for_tis(result) |
| if ti_filters is not None: |
| tis = session.query(TI).filter(ti_filters) |
| elif exclude_task_ids is None: |
| pass # Disable filter if not set. |
| elif isinstance(next(iter(exclude_task_ids), None), str): |
| tis = tis.filter(TI.task_id.notin_(exclude_task_ids)) |
| else: |
| tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids))) |
| |
| return tis |
| |
| @provide_session |
| def set_task_instance_state( |
| self, |
| *, |
| task_id: str, |
| map_indexes: Collection[int] | None = None, |
| execution_date: datetime | None = None, |
| run_id: str | None = None, |
| state: TaskInstanceState, |
| upstream: bool = False, |
| downstream: bool = False, |
| future: bool = False, |
| past: bool = False, |
| commit: bool = True, |
| session=NEW_SESSION, |
| ) -> list[TaskInstance]: |
| """ |
| Set the state of a TaskInstance to the given state, and clear its downstream tasks that are |
| in failed or upstream_failed state. |
| |
| :param task_id: Task ID of the TaskInstance |
| :param map_indexes: Only set TaskInstance if its map_index matches. |
| If None (default), all mapped TaskInstances of the task are set. |
| :param execution_date: Execution date of the TaskInstance |
| :param run_id: The run_id of the TaskInstance |
| :param state: State to set the TaskInstance to |
| :param upstream: Include all upstream tasks of the given task_id |
| :param downstream: Include all downstream tasks of the given task_id |
| :param future: Include all future TaskInstances of the given task_id |
| :param commit: Commit changes |
| :param past: Include all past TaskInstances of the given task_id |
| """ |
| from airflow.api.common.mark_tasks import set_state |
| |
| if not exactly_one(execution_date, run_id): |
| raise ValueError("Exactly one of execution_date or run_id must be provided") |
| |
| task = self.get_task(task_id) |
| task.dag = self |
| |
| tasks_to_set_state: list[Operator | tuple[Operator, int]] |
| if map_indexes is None: |
| tasks_to_set_state = [task] |
| else: |
| tasks_to_set_state = [(task, map_index) for map_index in map_indexes] |
| |
| altered = set_state( |
| tasks=tasks_to_set_state, |
| execution_date=execution_date, |
| run_id=run_id, |
| upstream=upstream, |
| downstream=downstream, |
| future=future, |
| past=past, |
| state=state, |
| commit=commit, |
| session=session, |
| ) |
| |
| if not commit: |
| return altered |
| |
| # Clear downstream tasks that are in failed/upstream_failed state to resume them. |
| # Flush the session so that the tasks marked success are reflected in the db. |
| session.flush() |
| subdag = self.partial_subset( |
| task_ids_or_regex={task_id}, |
| include_downstream=True, |
| include_upstream=False, |
| ) |
| |
| if execution_date is None: |
| dag_run = ( |
| session.query(DagRun).filter(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id).one() |
| ) # Raises an error if not found |
| resolve_execution_date = dag_run.execution_date |
| else: |
| resolve_execution_date = execution_date |
| |
| end_date = resolve_execution_date if not future else None |
| start_date = resolve_execution_date if not past else None |
| |
| subdag.clear( |
| start_date=start_date, |
| end_date=end_date, |
| include_subdags=True, |
| include_parentdag=True, |
| only_failed=True, |
| session=session, |
| # Exclude the task itself from being cleared |
| exclude_task_ids=frozenset({task_id}), |
| ) |
| |
| return altered |
| |
| @provide_session |
| def set_task_group_state( |
| self, |
| *, |
| group_id: str, |
| execution_date: datetime | None = None, |
| run_id: str | None = None, |
| state: TaskInstanceState, |
| upstream: bool = False, |
| downstream: bool = False, |
| future: bool = False, |
| past: bool = False, |
| commit: bool = True, |
| session: Session = NEW_SESSION, |
| ) -> list[TaskInstance]: |
| """ |
| Set the state of the TaskGroup to the given state, and clear its downstream tasks that are |
| in failed or upstream_failed state. |
| |
| :param group_id: The group_id of the TaskGroup |
| :param execution_date: Execution date of the TaskInstance |
| :param run_id: The run_id of the TaskInstance |
| :param state: State to set the TaskInstance to |
| :param upstream: Include all upstream tasks of the given task_id |
| :param downstream: Include all downstream tasks of the given task_id |
| :param future: Include all future TaskInstances of the given task_id |
| :param commit: Commit changes |
| :param past: Include all past TaskInstances of the given task_id |
| :param session: new session |
| """ |
| from airflow.api.common.mark_tasks import set_state |
| |
| if not exactly_one(execution_date, run_id): |
| raise ValueError("Exactly one of execution_date or run_id must be provided") |
| |
| tasks_to_set_state: list[BaseOperator | tuple[BaseOperator, int]] = [] |
| task_ids: list[str] = [] |
| locked_dag_run_ids: list[int] = [] |
| |
| if execution_date is None: |
| dag_run = ( |
| session.query(DagRun).filter(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id).one() |
| ) # Raises an error if not found |
| resolve_execution_date = dag_run.execution_date |
| else: |
| resolve_execution_date = execution_date |
| |
| end_date = resolve_execution_date if not future else None |
| start_date = resolve_execution_date if not past else None |
| |
| task_group_dict = self.task_group.get_task_group_dict() |
| task_group = task_group_dict.get(group_id) |
| if task_group is None: |
| raise ValueError("TaskGroup {group_id} could not be found") |
| tasks_to_set_state = [task for task in task_group.iter_tasks() if isinstance(task, BaseOperator)] |
| task_ids = [task.task_id for task in task_group.iter_tasks()] |
| dag_runs_query = session.query(DagRun.id).filter(DagRun.dag_id == self.dag_id).with_for_update() |
| |
| if start_date is None and end_date is None: |
| dag_runs_query = dag_runs_query.filter(DagRun.execution_date == start_date) |
| else: |
| if start_date is not None: |
| dag_runs_query = dag_runs_query.filter(DagRun.execution_date >= start_date) |
| |
| if end_date is not None: |
| dag_runs_query = dag_runs_query.filter(DagRun.execution_date <= end_date) |
| |
| locked_dag_run_ids = dag_runs_query.all() |
| |
| altered = set_state( |
| tasks=tasks_to_set_state, |
| execution_date=execution_date, |
| run_id=run_id, |
| upstream=upstream, |
| downstream=downstream, |
| future=future, |
| past=past, |
| state=state, |
| commit=commit, |
| session=session, |
| ) |
| |
| if not commit: |
| del locked_dag_run_ids |
| return altered |
| |
| # Clear downstream tasks that are in failed/upstream_failed state to resume them. |
| # Flush the session so that the tasks marked success are reflected in the db. |
| session.flush() |
| task_subset = self.partial_subset( |
| task_ids_or_regex=task_ids, |
| include_downstream=True, |
| include_upstream=False, |
| ) |
| |
| task_subset.clear( |
| start_date=start_date, |
| end_date=end_date, |
| include_subdags=True, |
| include_parentdag=True, |
| only_failed=True, |
| session=session, |
| # Exclude the task from the current group from being cleared |
| exclude_task_ids=frozenset(task_ids), |
| ) |
| |
| del locked_dag_run_ids |
| return altered |
| |
| @property |
| def roots(self) -> list[Operator]: |
| """Return nodes with no parents. These are first to execute and are called roots or root nodes.""" |
| return [task for task in self.tasks if not task.upstream_list] |
| |
| @property |
| def leaves(self) -> list[Operator]: |
| """Return nodes with no children. These are last to execute and are called leaves or leaf nodes.""" |
| return [task for task in self.tasks if not task.downstream_list] |
| |
| def topological_sort(self, include_subdag_tasks: bool = False): |
| """ |
| Sorts tasks in topographical order, such that a task comes after any of its |
| upstream dependencies. |
| |
| Deprecated in place of ``task_group.topological_sort`` |
| """ |
| from airflow.utils.task_group import TaskGroup |
| |
| def nested_topo(group): |
| for node in group.topological_sort(_include_subdag_tasks=include_subdag_tasks): |
| if isinstance(node, TaskGroup): |
| yield from nested_topo(node) |
| else: |
| yield node |
| |
| return tuple(nested_topo(self.task_group)) |
| |
| @provide_session |
| def set_dag_runs_state( |
| self, |
| state: str = State.RUNNING, |
| session: Session = NEW_SESSION, |
| start_date: datetime | None = None, |
| end_date: datetime | None = None, |
| dag_ids: list[str] = [], |
| ) -> None: |
| warnings.warn( |
| "This method is deprecated and will be removed in a future version.", |
| RemovedInAirflow3Warning, |
| stacklevel=3, |
| ) |
| dag_ids = dag_ids or [self.dag_id] |
| query = session.query(DagRun).filter(DagRun.dag_id.in_(dag_ids)) |
| if start_date: |
| query = query.filter(DagRun.execution_date >= start_date) |
| if end_date: |
| query = query.filter(DagRun.execution_date <= end_date) |
| query.update({DagRun.state: state}, synchronize_session="fetch") |
| |
| @provide_session |
| def clear( |
| self, |
| task_ids: Collection[str | tuple[str, int]] | None = None, |
| start_date: datetime | None = None, |
| end_date: datetime | None = None, |
| only_failed: bool = False, |
| only_running: bool = False, |
| confirm_prompt: bool = False, |
| include_subdags: bool = True, |
| include_parentdag: bool = True, |
| dag_run_state: DagRunState = DagRunState.QUEUED, |
| dry_run: bool = False, |
| session: Session = NEW_SESSION, |
| get_tis: bool = False, |
| recursion_depth: int = 0, |
| max_recursion_depth: int | None = None, |
| dag_bag: DagBag | None = None, |
| exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), |
| ) -> int | Iterable[TaskInstance]: |
| """ |
| Clears a set of task instances associated with the current dag for |
| a specified date range. |
| |
| :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear |
| :param start_date: The minimum execution_date to clear |
| :param end_date: The maximum execution_date to clear |
| :param only_failed: Only clear failed tasks |
| :param only_running: Only clear running tasks. |
| :param confirm_prompt: Ask for confirmation |
| :param include_subdags: Clear tasks in subdags and clear external tasks |
| indicated by ExternalTaskMarker |
| :param include_parentdag: Clear tasks in the parent dag of the subdag. |
| :param dag_run_state: state to set DagRun to. If set to False, dagrun state will not |
| be changed. |
| :param dry_run: Find the tasks to clear but don't clear them. |
| :param session: The sqlalchemy session to use |
| :param dag_bag: The DagBag used to find the dags subdags (Optional) |
| :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``) |
| tuples that should not be cleared |
| """ |
| if get_tis: |
| warnings.warn( |
| "Passing `get_tis` to dag.clear() is deprecated. Use `dry_run` parameter instead.", |
| RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| dry_run = True |
| |
| if recursion_depth: |
| warnings.warn( |
| "Passing `recursion_depth` to dag.clear() is deprecated.", |
| RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| if max_recursion_depth: |
| warnings.warn( |
| "Passing `max_recursion_depth` to dag.clear() is deprecated.", |
| RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| |
| state = [] |
| if only_failed: |
| state += [State.FAILED, State.UPSTREAM_FAILED] |
| if only_running: |
| # Yes, having `+=` doesn't make sense, but this was the existing behaviour |
| state += [State.RUNNING] |
| |
| tis = self._get_task_instances( |
| task_ids=task_ids, |
| start_date=start_date, |
| end_date=end_date, |
| run_id=None, |
| state=state, |
| include_subdags=include_subdags, |
| include_parentdag=include_parentdag, |
| include_dependent_dags=include_subdags, # compat, yes this is not a typo |
| session=session, |
| dag_bag=dag_bag, |
| exclude_task_ids=exclude_task_ids, |
| ) |
| |
| if dry_run: |
| return tis |
| |
| tis = list(tis) |
| |
| count = len(tis) |
| do_it = True |
| if count == 0: |
| return 0 |
| if confirm_prompt: |
| ti_list = "\n".join(str(t) for t in tis) |
| question = ( |
| "You are about to delete these {count} tasks:\n{ti_list}\n\nAre you sure? [y/n]" |
| ).format(count=count, ti_list=ti_list) |
| do_it = utils.helpers.ask_yesno(question) |
| |
| if do_it: |
| clear_task_instances( |
| tis, |
| session, |
| dag=self, |
| dag_run_state=dag_run_state, |
| ) |
| else: |
| count = 0 |
| print("Cancelled, nothing was cleared.") |
| |
| session.flush() |
| return count |
| |
| @classmethod |
| def clear_dags( |
| cls, |
| dags, |
| start_date=None, |
| end_date=None, |
| only_failed=False, |
| only_running=False, |
| confirm_prompt=False, |
| include_subdags=True, |
| include_parentdag=False, |
| dag_run_state=DagRunState.QUEUED, |
| dry_run=False, |
| ): |
| all_tis = [] |
| for dag in dags: |
| tis = dag.clear( |
| start_date=start_date, |
| end_date=end_date, |
| only_failed=only_failed, |
| only_running=only_running, |
| confirm_prompt=False, |
| include_subdags=include_subdags, |
| include_parentdag=include_parentdag, |
| dag_run_state=dag_run_state, |
| dry_run=True, |
| ) |
| all_tis.extend(tis) |
| |
| if dry_run: |
| return all_tis |
| |
| count = len(all_tis) |
| do_it = True |
| if count == 0: |
| print("Nothing to clear.") |
| return 0 |
| if confirm_prompt: |
| ti_list = "\n".join(str(t) for t in all_tis) |
| question = f"You are about to delete these {count} tasks:\n{ti_list}\n\nAre you sure? [y/n]" |
| do_it = utils.helpers.ask_yesno(question) |
| |
| if do_it: |
| for dag in dags: |
| dag.clear( |
| start_date=start_date, |
| end_date=end_date, |
| only_failed=only_failed, |
| only_running=only_running, |
| confirm_prompt=False, |
| include_subdags=include_subdags, |
| dag_run_state=dag_run_state, |
| dry_run=False, |
| ) |
| else: |
| count = 0 |
| print("Cancelled, nothing was cleared.") |
| return count |
| |
| def __deepcopy__(self, memo): |
| # Switcharoo to go around deepcopying objects coming through the |
| # backdoor |
| cls = self.__class__ |
| result = cls.__new__(cls) |
| memo[id(self)] = result |
| for k, v in self.__dict__.items(): |
| if k not in ("user_defined_macros", "user_defined_filters", "_log"): |
| setattr(result, k, copy.deepcopy(v, memo)) |
| |
| result.user_defined_macros = self.user_defined_macros |
| result.user_defined_filters = self.user_defined_filters |
| if hasattr(self, "_log"): |
| result._log = self._log |
| return result |
| |
| def sub_dag(self, *args, **kwargs): |
| """This method is deprecated in favor of partial_subset.""" |
| warnings.warn( |
| "This method is deprecated and will be removed in a future version. Please use partial_subset", |
| RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| return self.partial_subset(*args, **kwargs) |
| |
    def partial_subset(
        self,
        task_ids_or_regex: str | re.Pattern | Iterable[str],
        include_downstream: bool = False,
        include_upstream: bool = True,
        include_direct_upstream: bool = False,
    ):
        """
        Returns a subset of the current dag as a deep copy of the current dag
        based on a regex that should match one or many tasks, and includes
        upstream and downstream neighbours based on the flag passed.

        A string or compiled pattern is applied with ``re.findall``, so it matches
        anywhere inside a task id unless the pattern is anchored. Any other iterable
        is treated as a collection of exact task ids.

        :param task_ids_or_regex: Either a list of task_ids, or a regex to
            match against task ids (as a string, or compiled regex pattern).
        :param include_downstream: Include all downstream tasks of matched
            tasks, in addition to matched tasks.
        :param include_upstream: Include all upstream tasks of matched tasks,
            in addition to matched tasks.
        :param include_direct_upstream: Include all tasks directly upstream of matched
            and downstream (if include_downstream = True) tasks
        :return: A new DAG object holding deep copies of the selected tasks and a filtered
            copy of the TaskGroup hierarchy. Its ``partial`` attribute is set to ``True``
            when fewer tasks were kept than this DAG has.
        """
        from airflow.models.baseoperator import BaseOperator
        from airflow.models.mappedoperator import MappedOperator

        # deep-copying self.task_dict and self._task_group takes a long time, and we don't want all
        # the tasks anyway, so we copy the tasks manually later
        # (mapping their ids to None in the memo makes the copied DAG hold None for both
        # attributes until they are reassigned below).
        memo = {id(self.task_dict): None, id(self._task_group): None}
        dag = copy.deepcopy(self, memo)  # type: ignore

        if isinstance(task_ids_or_regex, (str, re.Pattern)):
            matched_tasks = [t for t in self.tasks if re.findall(task_ids_or_regex, t.task_id)]
        else:
            matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex]

        # Transitive relatives of every matched task (may contain duplicates; they collapse
        # when the task_dict is built below).
        also_include: list[Operator] = []
        for t in matched_tasks:
            if include_downstream:
                also_include.extend(t.get_flat_relatives(upstream=False))
            if include_upstream:
                also_include.extend(t.get_flat_relatives(upstream=True))

        # Only immediate parents, restricted to BaseOperator / MappedOperator instances.
        direct_upstreams: list[Operator] = []
        if include_direct_upstream:
            for t in itertools.chain(matched_tasks, also_include):
                upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator)))
                direct_upstreams.extend(upstream)

        # Compiling the unique list of tasks that made the cut
        # Make sure to not recursively deepcopy the dag or task_group while copying the task.
        # task_group is reset later
        def _deepcopy_task(t) -> Operator:
            memo.setdefault(id(t.task_group), None)
            return copy.deepcopy(t, memo)

        # Keyed by task_id, so a task listed several times is kept once; the shared memo also
        # returns the same copy if the same task object is deep-copied again.
        dag.task_dict = {
            t.task_id: _deepcopy_task(t)
            for t in itertools.chain(matched_tasks, also_include, direct_upstreams)
        }

        def filter_task_group(group, parent_group):
            """Exclude tasks not included in the subdag from the given TaskGroup."""
            # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy
            # and then manually deep copy the instances. (memo argument to deepcopy only works for instances
            # of classes, not "native" properties of an instance)
            copied = copy.copy(group)

            # The copy starts with an empty `children` dict (refilled below) and, for nested
            # groups, points at the already-filtered parent instead of a fresh deep copy.
            memo[id(group.children)] = {}
            if parent_group:
                memo[id(group.parent_group)] = parent_group
            for attr, value in copied.__dict__.items():
                if id(value) in memo:
                    value = memo[id(value)]
                else:
                    value = copy.deepcopy(value, memo)
                copied.__dict__[attr] = value

            # Children reference the copied group through a weak proxy rather than a strong ref.
            proxy = weakref.proxy(copied)

            for child in group.children.values():
                if isinstance(child, AbstractOperator):
                    if child.task_id in dag.task_dict:
                        task = copied.children[child.task_id] = dag.task_dict[child.task_id]
                        task.task_group = proxy
                    else:
                        # Task did not make the cut: release its id in the copied group.
                        copied.used_group_ids.discard(child.task_id)
                else:
                    filtered_child = filter_task_group(child, proxy)

                    # Only include this child TaskGroup if it is non-empty.
                    if filtered_child.children:
                        copied.children[child.group_id] = filtered_child

            return copied

        dag._task_group = filter_task_group(self.task_group, None)

        # Removing upstream/downstream references to tasks and TaskGroups that did not make
        # the cut.
        subdag_task_groups = dag.task_group.get_task_group_dict()
        for group in subdag_task_groups.values():
            group.upstream_group_ids.intersection_update(subdag_task_groups)
            group.downstream_group_ids.intersection_update(subdag_task_groups)
            group.upstream_task_ids.intersection_update(dag.task_dict)
            group.downstream_task_ids.intersection_update(dag.task_dict)

        for t in dag.tasks:
            # Removing upstream/downstream references to tasks that did not
            # make the cut
            t.upstream_task_ids.intersection_update(dag.task_dict)
            t.downstream_task_ids.intersection_update(dag.task_dict)

        if len(dag.tasks) < len(self.tasks):
            dag.partial = True

        return dag
| |
| def has_task(self, task_id: str): |
| return task_id in self.task_dict |
| |
| def has_task_group(self, task_group_id: str) -> bool: |
| return task_group_id in self.task_group_dict |
| |
| @cached_property |
| def task_group_dict(self): |
| return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None} |
| |
| def get_task(self, task_id: str, include_subdags: bool = False) -> Operator: |
| if task_id in self.task_dict: |
| return self.task_dict[task_id] |
| if include_subdags: |
| for dag in self.subdags: |
| if task_id in dag.task_dict: |
| return dag.task_dict[task_id] |
| raise TaskNotFound(f"Task {task_id} not found") |
| |
| def pickle_info(self): |
| d = {} |
| d["is_picklable"] = True |
| try: |
| dttm = timezone.utcnow() |
| pickled = pickle.dumps(self) |
| d["pickle_len"] = len(pickled) |
| d["pickling_duration"] = str(timezone.utcnow() - dttm) |
| except Exception as e: |
| self.log.debug(e) |
| d["is_picklable"] = False |
| d["stacktrace"] = traceback.format_exc() |
| return d |
| |
| @provide_session |
| def pickle(self, session=NEW_SESSION) -> DagPickle: |
| dag = session.query(DagModel).filter(DagModel.dag_id == self.dag_id).first() |
| dp = None |
| if dag and dag.pickle_id: |
| dp = session.query(DagPickle).filter(DagPickle.id == dag.pickle_id).first() |
| if not dp or dp.pickle != self: |
| dp = DagPickle(dag=self) |
| session.add(dp) |
| self.last_pickled = timezone.utcnow() |
| session.commit() |
| self.pickle_id = dp.id |
| |
| return dp |
| |
| def tree_view(self) -> None: |
| """Print an ASCII tree representation of the DAG.""" |
| |
| def get_downstream(task, level=0): |
| print((" " * level * 4) + str(task)) |
| level += 1 |
| for t in task.downstream_list: |
| get_downstream(t, level) |
| |
| for t in self.roots: |
| get_downstream(t) |
| |
| @property |
| def task(self) -> TaskDecoratorCollection: |
| from airflow.decorators import task |
| |
| return cast("TaskDecoratorCollection", functools.partial(task, dag=self)) |
| |
    def add_task(self, task: Operator) -> None:
        """
        Add a task to the DAG.

        The task's start/end dates are reconciled with the DAG's, the task is attached to the
        TaskGroup from ``TaskGroupContext`` if it has no group yet, and it is registered in
        ``self.task_dict`` under the resulting task id.

        :param task: the task you want to add
        :raises AirflowException: if neither the DAG nor the task has a ``start_date``.
        :raises DuplicateTaskIdFound: if the resulting task id maps to a different task object
            in ``task_dict``, or is already in the root TaskGroup's ``used_group_ids``.
        """
        # NOTE(review): presumably rejects trigger rules not allowed for this DAG; see
        # DagInvalidTriggerRule.check for the exact condition.
        DagInvalidTriggerRule.check(self, task.trigger_rule)

        from airflow.utils.task_group import TaskGroupContext

        if not self.start_date and not task.start_date:
            raise AirflowException("DAG is missing the start_date parameter")
        # if the task has no start date, assign it the same as the DAG
        elif not task.start_date:
            task.start_date = self.start_date
        # otherwise, the task will start on the later of its own start date and
        # the DAG's start date
        elif self.start_date:
            task.start_date = max(task.start_date, self.start_date)

        # if the task has no end date, assign it the same as the dag
        if not task.end_date:
            task.end_date = self.end_date
        # otherwise, the task will end on the earlier of its own end date and
        # the DAG's end date
        elif task.end_date and self.end_date:
            task.end_date = min(task.end_date, self.end_date)

        task_id = task.task_id
        # A task without a group joins the group active in the current context (if any); the
        # id used below becomes whatever `child_id` returns (presumably group-prefixed).
        # Note: the task is added to that group before the duplicate check that follows.
        if not task.task_group:
            task_group = TaskGroupContext.get_current_task_group(self)
            if task_group:
                task_id = task_group.child_id(task_id)
                task_group.add(task)

        if (
            task_id in self.task_dict and self.task_dict[task_id] is not task
        ) or task_id in self._task_group.used_group_ids:
            raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG")
        else:
            self.task_dict[task_id] = task
            task.dag = self
            # Add task_id to used_group_ids to prevent group_id and task_id collisions.
            self._task_group.used_group_ids.add(task_id)

        # Keep the cached count in sync with task_dict.
        self.task_count = len(self.task_dict)
| |
| def add_tasks(self, tasks: Iterable[Operator]) -> None: |
| """ |
| Add a list of tasks to the DAG. |
| |
| :param tasks: a lit of tasks you want to add |
| """ |
| for task in tasks: |
| self.add_task(task) |
| |
| def _remove_task(self, task_id: str) -> None: |
| # This is "private" as removing could leave a hole in dependencies if done incorrectly, and this |
| # doesn't guard against that |
| task = self.task_dict.pop(task_id) |
| tg = getattr(task, "task_group", None) |
| if tg: |
| tg._remove(task) |
| |
| self.task_count = len(self.task_dict) |
| |
| def run( |
| self, |
| start_date=None, |
| end_date=None, |
| mark_success=False, |
| local=False, |
| executor=None, |
| donot_pickle=conf.getboolean("core", "donot_pickle"), |
| ignore_task_deps=False, |
| ignore_first_depends_on_past=True, |
| pool=None, |
| delay_on_limit_secs=1.0, |
| verbose=False, |
| conf=None, |
| rerun_failed_tasks=False, |
| run_backwards=False, |
| run_at_least_once=False, |
| continue_on_failures=False, |
| disable_retry=False, |
| ): |
| """ |
| Runs the DAG. |
| |
| :param start_date: the start date of the range to run |
| :param end_date: the end date of the range to run |
| :param mark_success: True to mark jobs as succeeded without running them |
| :param local: True to run the tasks using the LocalExecutor |
| :param executor: The executor instance to run the tasks |
| :param donot_pickle: True to avoid pickling DAG object and send to workers |
| :param ignore_task_deps: True to skip upstream tasks |
| :param ignore_first_depends_on_past: True to ignore depends_on_past |
| dependencies for the first set of tasks only |
| :param pool: Resource pool to use |
| :param delay_on_limit_secs: Time in seconds to wait before next attempt to run |
| dag run when max_active_runs limit has been reached |
| :param verbose: Make logging output more verbose |
| :param conf: user defined dictionary passed from CLI |
| :param rerun_failed_tasks: |
| :param run_backwards: |
| :param run_at_least_once: If true, always run the DAG at least once even |
| if no logical run exists within the time range. |
| """ |
| from airflow.jobs.backfill_job_runner import BackfillJobRunner |
| |
| if not executor and local: |
| from airflow.executors.local_executor import LocalExecutor |
| |
| executor = LocalExecutor() |
| elif not executor: |
| from airflow.executors.executor_loader import ExecutorLoader |
| |
| executor = ExecutorLoader.get_default_executor() |
| from airflow.jobs.job import Job |
| |
| job = Job(executor=executor) |
| job_runner = BackfillJobRunner( |
| job=job, |
| dag=self, |
| start_date=start_date, |
| end_date=end_date, |
| mark_success=mark_success, |
| donot_pickle=donot_pickle, |
| ignore_task_deps=ignore_task_deps, |
| ignore_first_depends_on_past=ignore_first_depends_on_past, |
| pool=pool, |
| delay_on_limit_secs=delay_on_limit_secs, |
| verbose=verbose, |
| conf=conf, |
| rerun_failed_tasks=rerun_failed_tasks, |
| run_backwards=run_backwards, |
| run_at_least_once=run_at_least_once, |
| continue_on_failures=continue_on_failures, |
| disable_retry=disable_retry, |
| ) |
| run_job(job=job, execute_callable=job_runner._execute) |
| |
| def cli(self): |
| """Exposes a CLI specific to this DAG.""" |
| check_cycle(self) |
| |
| from airflow.cli import cli_parser |
| |
| parser = cli_parser.get_parser(dag_parser=True) |
| args = parser.parse_args() |
| args.func(args, self) |
| |
@provide_session
def test(
    self,
    execution_date: datetime | None = None,
    run_conf: dict[str, Any] | None = None,
    conn_file_path: str | None = None,
    variable_file_path: str | None = None,
    session: Session = NEW_SESSION,
) -> None:
    """
    Execute one single DagRun for a given DAG and execution date.

    Existing task instances for ``execution_date`` are cleared first. The DAG run is then fetched
    or created, and its schedulable task instances are run in-process until the run is no longer
    RUNNING. A local secrets backend is added when connection/variable files are given. It is
    always removed again, even if the run raises.

    :param execution_date: execution date for the DAG run; defaults to the current UTC time
    :param run_conf: configuration to pass to newly created dagrun
    :param conn_file_path: file path to a connection file in either yaml or json
    :param variable_file_path: file path to a variable file in either yaml or json
    :param session: database connection (optional)
    """

    def add_logger_if_needed(ti: TaskInstance):
        """Add a formatted logger to the task instance.

        This allows all logs to surface to the command line, instead of into
        a task file. Since this is a local test run, it is much better for
        the user to see logs in the command line, rather than needing to
        search for a log file.

        :param ti: The task instance that will receive a logger.
        """
        # only add log handler once
        if any(isinstance(h, logging.StreamHandler) for h in ti.log.handlers):
            return
        formatter = logging.Formatter("[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s")
        handler = logging.StreamHandler(sys.stdout)
        handler.level = logging.INFO
        handler.setFormatter(formatter)
        self.log.debug("Adding Streamhandler to taskinstance %s", ti.task_id)
        ti.log.addHandler(handler)

    # File-based connections/variables are inserted at the front of the shared backend list
    # (presumably consulted first) for the duration of this test run only.
    local_secrets = None
    if conn_file_path or variable_file_path:
        local_secrets = LocalFilesystemBackend(
            variables_file_path=variable_file_path, connections_file_path=conn_file_path
        )
        secrets_backend_list.insert(0, local_secrets)

    try:
        execution_date = execution_date or timezone.utcnow()
        self.log.debug("Clearing existing task instances for execution date %s", execution_date)
        self.clear(
            start_date=execution_date,
            end_date=execution_date,
            dag_run_state=False,  # type: ignore
            session=session,
        )
        self.log.debug("Getting dagrun for dag %s", self.dag_id)
        dr: DagRun = _get_or_create_dagrun(
            dag=self,
            start_date=execution_date,
            execution_date=execution_date,
            run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
            session=session,
            conf=run_conf,
        )

        tasks = self.task_dict
        self.log.debug("starting dagrun")
        # Instead of starting a scheduler, we run the minimal loop possible to check
        # for task readiness and dependency management. This is notably faster
        # than creating a BackfillJob and allows us to surface logs to the user
        while dr.state == State.RUNNING:
            schedulable_tis, _ = dr.update_state(session=session)
            for ti in schedulable_tis:
                add_logger_if_needed(ti)
                ti.task = tasks[ti.task_id]
                _run_task(ti, session=session)
    finally:
        if local_secrets is not None:
            # Remove exactly the backend added above, not whatever sits at index 0, and do it
            # even when the run raised, so the shared list is left as it was found.
            secrets_backend_list.remove(local_secrets)
| |
@provide_session
def create_dagrun(
    self,
    state: DagRunState,
    execution_date: datetime | None = None,
    run_id: str | None = None,
    start_date: datetime | None = None,
    external_trigger: bool | None = False,
    conf: dict | None = None,
    run_type: DagRunType | None = None,
    session: Session = NEW_SESSION,
    dag_hash: str | None = None,
    creating_job_id: int | None = None,
    data_interval: tuple[datetime, datetime] | None = None,
):
    """
    Creates a dag run from this dag including the tasks associated with this dag.
    Returns the dag run.

    Either ``run_id`` or both ``run_type`` and ``execution_date`` must be given:

    * a missing ``run_type`` is inferred from ``run_id``;
    * a missing ``run_id`` is generated by the timetable;
    * a missing ``data_interval`` is inferred from ``execution_date``, with a deprecation warning.

    ``conf`` is validated against a copy of the DAG's params before the run is built. The new run
    is added to ``session`` and flushed, and its task instances are created through
    ``run.verify_integrity``.

    :param run_id: defines the run id for this dag run
    :param run_type: type of DagRun
    :param execution_date: the execution date of this dag run
    :param state: the state of the dag run
    :param start_date: the date this dag run should be evaluated
    :param external_trigger: whether this dag run is externally triggered
    :param conf: Dict containing configuration/parameters to pass to the DAG
    :param creating_job_id: id of the job creating this DagRun
    :param session: database session
    :param dag_hash: Hash of Serialized DAG
    :param data_interval: Data interval of the DagRun
    :return: the newly created DagRun
    :raises ValueError: if ``run_type`` is neither a DagRunType nor a str, if ``run_id`` is not a
        str, or if a MANUAL run is given a ``run_id`` reserved for another run type
    :raises AirflowException: if neither ``run_id`` nor both ``run_type`` and ``execution_date``
        are given
    """
    # NOTE: the ``conf`` parameter shadows the module-level ``conf`` configuration object here.
    logical_date = timezone.coerce_datetime(execution_date)

    # Normalize a plain (start, end) tuple into a DataInterval with coerced datetimes.
    if data_interval and not isinstance(data_interval, DataInterval):
        data_interval = DataInterval(*map(timezone.coerce_datetime, data_interval))

    if data_interval is None and logical_date is not None:
        # presumably stacklevel=3 skips this frame and the @provide_session wrapper so the
        # warning points at the caller -- confirm before refactoring this method.
        warnings.warn(
            "Calling `DAG.create_dagrun()` without an explicit data interval is deprecated",
            RemovedInAirflow3Warning,
            stacklevel=3,
        )
        if run_type == DagRunType.MANUAL:
            data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date)
        else:
            data_interval = self.infer_automated_data_interval(logical_date)

    if run_type is None or isinstance(run_type, DagRunType):
        pass
    elif isinstance(run_type, str):  # Compatibility: run_type used to be a str.
        run_type = DagRunType(run_type)
    else:
        raise ValueError(f"`run_type` should be a DagRunType, not {type(run_type)}")

    if run_id:  # Infer run_type from run_id if needed.
        if not isinstance(run_id, str):
            raise ValueError(f"`run_id` should be a str, not {type(run_id)}")
        inferred_run_type = DagRunType.from_run_id(run_id)
        if run_type is None:
            # No explicit type given, use the inferred type.
            run_type = inferred_run_type
        elif run_type == DagRunType.MANUAL and inferred_run_type != DagRunType.MANUAL:
            # Prevent a manual run from using an ID that looks like a scheduled run.
            raise ValueError(
                f"A {run_type.value} DAG run cannot use ID {run_id!r} since it "
                f"is reserved for {inferred_run_type.value} runs"
            )
    elif run_type and logical_date is not None:  # Generate run_id from run_type and execution_date.
        run_id = self.timetable.generate_run_id(
            run_type=run_type, logical_date=logical_date, data_interval=data_interval
        )
    else:
        raise AirflowException(
            "Creating DagRun needs either `run_id` or both `run_type` and `execution_date`"
        )

    if run_id and "/" in run_id:
        warnings.warn(
            "Using forward slash ('/') in a DAG run ID is deprecated. Note that this character "
            "also makes the run impossible to retrieve via Airflow's REST API.",
            RemovedInAirflow3Warning,
            stacklevel=3,
        )

    # create a copy of params before validating
    # (so the DAG's own params are not modified by the run's conf)
    copied_params = copy.deepcopy(self.params)
    copied_params.update(conf or {})
    copied_params.validate()

    run = DagRun(
        dag_id=self.dag_id,
        run_id=run_id,
        execution_date=logical_date,
        start_date=start_date,
        external_trigger=external_trigger,
        conf=conf,
        state=state,
        run_type=run_type,
        dag_hash=dag_hash,
        creating_job_id=creating_job_id,
        data_interval=data_interval,
    )
    session.add(run)
    session.flush()

    run.dag = self

    # create the associated task instances
    # state is None at the moment of creation
    run.verify_integrity(session=session)

    return run
| |
@classmethod
@provide_session
def bulk_sync_to_db(
    cls,
    dags: Collection[DAG],
    session=NEW_SESSION,
):
    """
    Deprecated alias of ``bulk_write_to_db``.

    Emits a ``RemovedInAirflow3Warning`` and then delegates to ``cls.bulk_write_to_db``.

    :param dags: the DAG objects to save to the DB
    :param session: ORM session
    :return: whatever ``bulk_write_to_db`` returns
    """
    deprecation_message = (
        "This method is deprecated and will be removed in a future version. Please use bulk_write_to_db"
    )
    warnings.warn(deprecation_message, RemovedInAirflow3Warning, stacklevel=2)
    return cls.bulk_write_to_db(dags=dags, session=session)
| |
@classmethod
@provide_session
def bulk_write_to_db(
    cls,
    dags: Collection[DAG],
    processor_subdir: str | None = None,
    session=NEW_SESSION,
):
    """
    Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB, including
    calculated fields.

    Note that this method can be called for both DAGs and SubDAGs. A SubDag is actually a SubDagOperator.

    In order, this method:

    1. loads the existing DagModel rows with ``with_row_locks``;
    2. creates rows for DAGs that have none;
    3. refreshes each row's attributes, tags and owner links;
    4. syncs source code through ``DagCode.bulk_sync_to_db``;
    5. reconciles dataset rows and the DAG-schedule / task-outlet dataset references;
    6. flushes the session;
    7. recurses into each DAG's subdags.

    The session is flushed but not committed here.

    :param dags: the DAG objects to save to the DB
    :param processor_subdir: base directory of the DAG processor that parsed these DAGs; written to
        every DagModel row touched (and passed on to the subdag recursion)
    :param session: ORM session
    :return: None
    """
    if not dags:
        return

    log.info("Sync %s DAGs", len(dags))
    dag_by_ids = {dag.dag_id: dag for dag in dags}

    dag_ids = set(dag_by_ids.keys())
    query = (
        session.query(DagModel)
        .options(joinedload(DagModel.tags, innerjoin=False))
        .filter(DagModel.dag_id.in_(dag_ids))
        .options(joinedload(DagModel.schedule_dataset_references))
        .options(joinedload(DagModel.task_outlet_dataset_references))
    )
    # Existing rows are loaded through with_row_locks so concurrent writers don't race on them.
    orm_dags: list[DagModel] = with_row_locks(query, of=DagModel, session=session).all()
    existing_dags = {orm_dag.dag_id: orm_dag for orm_dag in orm_dags}
    missing_dag_ids = dag_ids.difference(existing_dags)

    for missing_dag_id in missing_dag_ids:
        orm_dag = DagModel(dag_id=missing_dag_id)
        dag = dag_by_ids[missing_dag_id]
        if dag.is_paused_upon_creation is not None:
            orm_dag.is_paused = dag.is_paused_upon_creation
        orm_dag.tags = []
        log.info("Creating ORM DAG for %s", dag.dag_id)
        session.add(orm_dag)
        orm_dags.append(orm_dag)

    # Get the latest dag run for each existing dag as a single query (avoid n+1 query)
    # Only BACKFILL_JOB and SCHEDULED runs are considered here.
    most_recent_subq = (
        session.query(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
        .filter(
            DagRun.dag_id.in_(existing_dags),
            or_(DagRun.run_type == DagRunType.BACKFILL_JOB, DagRun.run_type == DagRunType.SCHEDULED),
        )
        .group_by(DagRun.dag_id)
        .subquery()
    )
    most_recent_runs_iter = session.query(DagRun).filter(
        DagRun.dag_id == most_recent_subq.c.dag_id,
        DagRun.execution_date == most_recent_subq.c.max_execution_date,
    )
    most_recent_runs = {run.dag_id: run for run in most_recent_runs_iter}

    # Get number of active dagruns for all dags we are processing as a single query.

    num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session)

    filelocs = []

    for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id):
        dag = dag_by_ids[orm_dag.dag_id]
        filelocs.append(dag.fileloc)
        if dag.is_subdag:
            # A subdag row stores its parent's file location and owner, and points at the
            # parent through root_dag_id.
            orm_dag.is_subdag = True
            orm_dag.fileloc = dag.parent_dag.fileloc  # type: ignore
            orm_dag.root_dag_id = dag.parent_dag.dag_id  # type: ignore
            orm_dag.owners = dag.parent_dag.owner  # type: ignore
        else:
            orm_dag.is_subdag = False
            orm_dag.fileloc = dag.fileloc
            orm_dag.owners = dag.owner
        orm_dag.is_active = True
        orm_dag.has_import_errors = False
        orm_dag.last_parsed_time = timezone.utcnow()
        orm_dag.default_view = dag.default_view
        orm_dag.description = dag.description
        orm_dag.max_active_tasks = dag.max_active_tasks
        orm_dag.max_active_runs = dag.max_active_runs
        orm_dag.has_task_concurrency_limits = any(
            t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
            for t in dag.tasks
        )
        orm_dag.schedule_interval = dag.schedule_interval
        orm_dag.timetable_description = dag.timetable.description
        orm_dag.processor_subdir = processor_subdir

        run: DagRun | None = most_recent_runs.get(dag.dag_id)
        if run is None:
            data_interval = None
        else:
            data_interval = dag.get_run_data_interval(run)
        # A DAG already at its max_active_runs limit gets no next_dagrun_create_after for now.
        if num_active_runs.get(dag.dag_id, 0) >= orm_dag.max_active_runs:
            orm_dag.next_dagrun_create_after = None
        else:
            orm_dag.calculate_dagrun_date_fields(dag, data_interval)

        # Sync tags: delete stored tags no longer declared on the DAG, then add new ones.
        dag_tags = set(dag.tags or {})
        # Iterate over a copy because orm_dag.tags is mutated inside the loop.
        orm_dag_tags = list(orm_dag.tags or [])
        for orm_tag in orm_dag_tags:
            if orm_tag.name not in dag_tags:
                session.delete(orm_tag)
                orm_dag.tags.remove(orm_tag)
        orm_tag_names = {t.name for t in orm_dag_tags}
        for dag_tag in dag_tags:
            if dag_tag not in orm_tag_names:
                dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id)
                orm_dag.tags.append(dag_tag_orm)
                session.add(dag_tag_orm)

        orm_dag_links = orm_dag.dag_owner_links or []
        for orm_dag_link in orm_dag_links:
            # NOTE(review): ``dag.owner_links`` is iterated with ``.items()`` below, so it looks
            # like a dict keyed by owner name. A DagOwnerAttributes object would then never be
            # "in" it, so every stored link is deleted and re-added on each write. Comparing
            # ``orm_dag_link.owner`` would need existing rows updated in place instead of
            # re-added -- confirm before changing.
            if orm_dag_link not in dag.owner_links:
                session.delete(orm_dag_link)
        for owner_name, owner_link in dag.owner_links.items():
            dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link)
            session.add(dag_owner_orm)

    DagCode.bulk_sync_to_db(filelocs, session=session)

    from airflow.datasets import Dataset
    from airflow.models.dataset import (
        DagScheduleDatasetReference,
        DatasetModel,
        TaskOutletDatasetReference,
    )

    # dag_id -> set of dataset URIs the DAG is scheduled on
    dag_references = collections.defaultdict(set)
    # (dag_id, task_id) -> set of dataset URIs the task declares as outlets
    outlet_references = collections.defaultdict(set)
    # We can't use a set here as we want to preserve order
    # (keys are DatasetModel objects built with DatasetModel.from_public)
    outlet_datasets: dict[Dataset, None] = {}
    input_datasets: dict[Dataset, None] = {}

    # here we go through dags and tasks to check for dataset references
    # if there are now None and previously there were some, we delete them
    # if there are now *any*, we add them to the above data structures, and
    # later we'll persist them to the database.
    for dag in dags:
        curr_orm_dag = existing_dags.get(dag.dag_id)
        if not dag.dataset_triggers:
            if curr_orm_dag and curr_orm_dag.schedule_dataset_references:
                curr_orm_dag.schedule_dataset_references = []
        for dataset in dag.dataset_triggers:
            dag_references[dag.dag_id].add(dataset.uri)
            input_datasets[DatasetModel.from_public(dataset)] = None
        curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
        for task in dag.tasks:
            dataset_outlets = [x for x in task.outlets or [] if isinstance(x, Dataset)]
            if not dataset_outlets:
                if curr_outlet_references:
                    this_task_outlet_refs = [
                        x
                        for x in curr_outlet_references
                        if x.dag_id == dag.dag_id and x.task_id == task.task_id
                    ]
                    for ref in this_task_outlet_refs:
                        curr_outlet_references.remove(ref)
            for d in dataset_outlets:
                outlet_references[(task.dag_id, task.task_id)].add(d.uri)
                outlet_datasets[DatasetModel.from_public(d)] = None
    # all_datasets is the outlet_datasets dict itself (not a copy), extended with the inputs.
    all_datasets = outlet_datasets
    all_datasets.update(input_datasets)

    # store datasets
    stored_datasets = {}
    for dataset in all_datasets:
        # NOTE(review): one SELECT per dataset; a single IN query could avoid the N+1 pattern.
        stored_dataset = session.query(DatasetModel).filter(DatasetModel.uri == dataset.uri).first()
        if stored_dataset:
            # Some datasets may have been previously unreferenced, and therefore orphaned by the
            # scheduler. But if we're here, then we have found that dataset again in our DAGs, which
            # means that it is no longer an orphan, so set is_orphaned to False.
            stored_dataset.is_orphaned = expression.false()
            stored_datasets[stored_dataset.uri] = stored_dataset
        else:
            session.add(dataset)
            stored_datasets[dataset.uri] = dataset

    session.flush()  # this is required to ensure each dataset has its PK loaded

    del all_datasets

    # reconcile dag-schedule-on-dataset references
    for dag_id, uri_list in dag_references.items():
        dag_refs_needed = {
            DagScheduleDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id)
            for uri in uri_list
        }
        dag_refs_stored = set(
            existing_dags.get(dag_id)
            and existing_dags.get(dag_id).schedule_dataset_references  # type: ignore
            or []
        )
        # presumably the reference models define __eq__/__hash__ on their key columns (elsewhere)
        # so these set operations compare rows by value -- TODO confirm
        dag_refs_to_add = {x for x in dag_refs_needed if x not in dag_refs_stored}
        session.bulk_save_objects(dag_refs_to_add)
        for obj in dag_refs_stored - dag_refs_needed:
            session.delete(obj)

    existing_task_outlet_refs_dict = collections.defaultdict(set)
    for dag_id, orm_dag in existing_dags.items():
        for todr in orm_dag.task_outlet_dataset_references:
            existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr)

    # reconcile task-outlet-dataset references
    for (dag_id, task_id), uri_list in outlet_references.items():
        task_refs_needed = {
            TaskOutletDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id, task_id=task_id)
            for uri in uri_list
        }
        task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)]
        task_refs_to_add = {x for x in task_refs_needed if x not in task_refs_stored}
        session.bulk_save_objects(task_refs_to_add)
        for obj in task_refs_stored - task_refs_needed:
            session.delete(obj)

    # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller
    # decide when to commit)
    session.flush()

    for dag in dags:
        cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session)
| |
@provide_session
def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION):
    """
    Persist this DAG's attributes to the DB.

    This is a thin wrapper around ``bulk_write_to_db`` with a one-element list. It may be called
    for both DAGs and SubDAGs (a SubDag is actually a SubDagOperator).

    :param processor_subdir: base directory of the DAG processor that parsed this DAG
    :param session: ORM session
    :return: None
    """
    single_dag = [self]
    self.bulk_write_to_db(single_dag, processor_subdir=processor_subdir, session=session)
| |
def get_default_view(self):
    """
    Return this DAG's default view name.

    If ``default_view`` is None, fall back to the lower-cased ``[webserver] dag_default_view``
    setting. This is only there for backward compatible jinja2 templates.
    """
    if self.default_view is not None:
        return self.default_view
    return conf.get("webserver", "dag_default_view").lower()
| |
@staticmethod
@provide_session
def deactivate_unknown_dags(active_dag_ids, session=NEW_SESSION):
    """
    Mark every DAG row whose ID is not in ``active_dag_ids`` as inactive.

    Returns immediately when ``active_dag_ids`` is empty. Otherwise the matching rows are set
    inactive and the session is committed.

    :param active_dag_ids: list of DAG IDs that are active
    :param session: ORM session
    :return: None
    """
    if not len(active_dag_ids):
        return
    unknown_dags = session.query(DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all()
    for unknown_dag in unknown_dags:
        unknown_dag.is_active = False
        session.merge(unknown_dag)
    session.commit()
| |
@staticmethod
@provide_session
def deactivate_stale_dags(expiration_date, session=NEW_SESSION):
    """
    Deactivate active DAGs whose ``last_parsed_time`` is earlier than ``expiration_date``.

    These DAGs were likely deleted. Each deactivation is logged, and the session is committed at
    the end.

    :param expiration_date: active DAGs last parsed before this time are set inactive
    :param session: ORM session
    :return: None
    """
    stale_query = session.query(DagModel).filter(
        DagModel.last_parsed_time < expiration_date, DagModel.is_active
    )
    for stale_dag in stale_query.all():
        log.info(
            "Deactivating DAG ID %s since it was last touched by the scheduler at %s",
            stale_dag.dag_id,
            stale_dag.last_parsed_time.isoformat(),
        )
        stale_dag.is_active = False
        session.merge(stale_dag)
    session.commit()
| |
@staticmethod
@provide_session
def get_num_task_instances(dag_id, run_id=None, task_ids=None, states=None, session=NEW_SESSION) -> int:
    """
    Count the task instances of a DAG, optionally narrowed by run, task IDs and states.

    :param dag_id: ID of the DAG whose task instances are counted
    :param run_id: if truthy, only count task instances of this DAG run
    :param task_ids: if truthy, only count task instances whose task ID is in this list
    :param states: if truthy, only count task instances in one of these states; a ``None`` entry
        also matches task instances that have no state
    :param session: ORM session
    :return: the number of matching task instances
    """
    conditions = [TaskInstance.dag_id == dag_id]
    if run_id:
        conditions.append(TaskInstance.run_id == run_id)
    if task_ids:
        conditions.append(TaskInstance.task_id.in_(task_ids))
    if states:
        if None not in states:
            conditions.append(TaskInstance.state.in_(states))
        elif all(state is None for state in states):
            conditions.append(TaskInstance.state.is_(None))
        else:
            real_states = [state for state in states if state]
            conditions.append(or_(TaskInstance.state.in_(real_states), TaskInstance.state.is_(None)))
    count_query = session.query(func.count(TaskInstance.task_id)).filter(*conditions)
    return count_query.scalar()
| |
@classmethod
def get_serialized_fields(cls):
    """
    Return the frozenset of DAG attribute names that are serialized.

    The set is the instance attributes of ``DAG(dag_id="test")`` minus an exclusion list. It is
    computed on first use and then cached on the class. Stringified DAGs and operators contain
    exactly these fields.
    """
    if cls.__serialized_fields:
        return cls.__serialized_fields

    not_serialized = {
        "parent_dag",
        "schedule_dataset_references",
        "task_outlet_dataset_references",
        "_old_context_manager_dags",
        "safe_dag_id",
        "last_loaded",
        "user_defined_filters",
        "user_defined_macros",
        "partial",
        "params",
        "_pickle_id",
        "_log",
        "task_dict",
        "template_searchpath",
        "sla_miss_callback",
        "on_success_callback",
        "on_failure_callback",
        "template_undefined",
        "jinja_environment_kwargs",
        # has_on_*_callback are only stored if the value is True, as the default is False
        "has_on_success_callback",
        "has_on_failure_callback",
        "auto_register",
        "fail_stop",
    }
    probe_dag_attributes = vars(DAG(dag_id="test"))
    cls.__serialized_fields = frozenset(probe_dag_attributes) - not_serialized
    return cls.__serialized_fields
| |
def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType:
    """
    Return the edge information stored for ``upstream_task_id`` -> ``downstream_task_id``.

    If no information is stored for that pair, an empty edge (``{}``) is returned.
    """
    empty = cast(EdgeInfoType, {})
    # Older serialized DAGs may not have edge_info being a dict at all.
    if not self.edge_info:
        return empty
    downstream_edges = self.edge_info.get(upstream_task_id, {})
    return downstream_edges.get(downstream_task_id, empty)
| |
def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType):
    """
    Store ``info`` for the edge ``upstream_task_id`` -> ``downstream_task_id``.

    Any existing information for that edge is overwritten, not merged.
    """
    if upstream_task_id not in self.edge_info:
        self.edge_info[upstream_task_id] = {}
    self.edge_info[upstream_task_id][downstream_task_id] = info
| |
def validate_schedule_and_params(self):
    """
    Raise if the DAG can be scheduled but has a required Param without a default value.

    A Param counts as required when it has no value and ``null`` is not among the types allowed by
    its ``schema["type"]``. Nothing is checked when ``self.timetable.can_run`` is false.

    :raises AirflowException: if a required Param without a default exists on a schedulable DAG
    """
    if not self.timetable.can_run:
        return

    for _name, param in self.params.items():
        if param.has_value:
            continue
        schema = param.schema
        # As type can be an array, check whether `null` is one of the allowed types.
        if "type" in schema and "null" in schema["type"]:
            continue
        raise AirflowException(
            "DAG Schedule must be None, if there are any required params without default values"
        )
| |
def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]:
    """
    Yield the entries of ``self.owner_links`` whose link is not valid.

    A link is valid in either of two cases:

    * it is a ``mailto:`` link with a non-empty address;
    * it is any other URL that has both a scheme and a network location (e.g. ``https://host``).

    :return: an iterator of invalid ``(owner, link)`` pairs, in ``owner_links`` order
    """
    for owner, link in self.owner_links.items():
        result = urlsplit(link)
        if result.scheme == "mailto":
            # netloc does not exist for 'mailto' links; the address is parsed into the path.
            if not result.path:
                # Report the owner name (previously the always-empty path was yielded instead).
                yield owner, link
        elif not result.scheme or not result.netloc:
            yield owner, link
| |
| |
class DagTag(Base):
    """A tag name per dag, to allow quick filtering in the DAG view."""

    __tablename__ = "dag_tag"
    # Tag text; (name, dag_id) together form the composite primary key.
    name = Column(String(TAG_MAX_LEN), primary_key=True)
    # Owning DAG. ondelete="CASCADE" makes the database drop tag rows when the dag row is deleted.
    dag_id = Column(
        StringID(),
        ForeignKey("dag.dag_id", name="dag_tag_dag_id_fkey", ondelete="CASCADE"),
        primary_key=True,
    )

    def __repr__(self):
        # NOTE(review): __repr__ must return a str; this raises TypeError while ``name`` is None.
        return self.name
| |
| |
class DagOwnerAttributes(Base):
    """Table defining different owner attributes.

    For example, a link for an owner that will be passed as a hyperlink to the
    "DAGs" view.
    """

    __tablename__ = "dag_owner_attributes"
    # (dag_id, owner) form the composite primary key, so there is at most one row per owner per DAG.
    # NOTE(review): the FK constraint is literally named "dag.dag_id"; renaming it would presumably
    # require a DB migration, so it is left as-is.
    dag_id = Column(
        StringID(),
        ForeignKey("dag.dag_id", name="dag.dag_id", ondelete="CASCADE"),
        nullable=False,
        primary_key=True,
    )
    owner = Column(String(500), primary_key=True, nullable=False)
    link = Column(String(500), nullable=False)

    def __repr__(self):
        return f"<DagOwnerAttributes: dag_id={self.dag_id}, owner={self.owner}, link={self.link}>"

    @classmethod
    def get_all(cls, session) -> dict[str, dict[str, str]]:
        """
        Load every owner link, grouped by DAG.

        :param session: ORM session used to query all rows of this table
        :return: ``{dag_id: {owner: link}}``; the returned object is a ``defaultdict(dict)``
        """
        dag_links: dict = collections.defaultdict(dict)
        for obj in session.query(cls):
            dag_links[obj.dag_id].update({obj.owner: obj.link})
        return dag_links
| |
| |
| class DagModel(Base): |
| """Table containing DAG properties.""" |
| |
| __tablename__ = "dag" |
| """ |
| These items are stored in the database for state related information |
| """ |
| dag_id = Column(StringID(), primary_key=True) |
| root_dag_id = Column(StringID()) |
| # A DAG can be paused from the UI / DB |
| # Set this default value of is_paused based on a configuration value! |
| is_paused_at_creation = conf.getboolean("core", "dags_are_paused_at_creation") |
| is_paused = Column(Boolean, default=is_paused_at_creation) |
| # Whether the DAG is a subdag |
| is_subdag = Column(Boolean, default=False) |
| # Whether that DAG was seen on the last DagBag load |
| is_active = Column(Boolean, default=False) |
| # Last time the scheduler started |
| last_parsed_time = Column(UtcDateTime) |
| # Last time this DAG was pickled |
| last_pickled = Column(UtcDateTime) |
| # Time when the DAG last received a refresh signal |
| # (e.g. the DAG's "refresh" button was clicked in the web UI) |
| last_expired = Column(UtcDateTime) |
| # Whether (one of) the scheduler is scheduling this DAG at the moment |
| scheduler_lock = Column(Boolean) |
| # Foreign key to the latest pickle_id |
| pickle_id = Column(Integer) |
| # The location of the file containing the DAG object |
| # Note: Do not depend on fileloc pointing to a file; in the case of a |
| # packaged DAG, it will point to the subpath of the DAG within the |
| # associated zip. |
| fileloc = Column(String(2000)) |
| # The base directory used by Dag Processor that parsed this dag. |
| processor_subdir = Column(String(2000), nullable=True) |
| # String representing the owners |
| owners = Column(String(2000)) |
| # Description of the dag |
| description = Column(Text) |
| # Default view of the DAG inside the webserver |
| default_view = Column(String(25)) |
| # Schedule interval |
| schedule_interval = Column(Interval) |
| # Timetable/Schedule Interval description |
| timetable_description = Column(String(1000), nullable=True) |
| # Tags for view filter |
| tags = relationship("DagTag", cascade="all, delete, delete-orphan", backref=backref("dag")) |
| # Dag owner links for DAGs view |
| dag_owner_links = relationship( |
| "DagOwnerAttributes", cascade="all, delete, delete-orphan", backref=backref("dag") |
| ) |
| |
| max_active_tasks = Column(Integer, nullable=False) |
| max_active_runs = Column(Integer, nullable=True) |
| |
| has_task_concurrency_limits = Column(Boolean, nullable=False) |
| has_import_errors = Column(Boolean(), default=False, server_default="0") |
| |
| # The logical date of the next dag run. |
| next_dagrun = Column(UtcDateTime) |
| |
| # Must be either both NULL or both datetime. |
| next_dagrun_data_interval_start = Column(UtcDateTime) |
| next_dagrun_data_interval_end = Column(UtcDateTime) |
| |
| # Earliest time at which this ``next_dagrun`` can be created. |
| next_dagrun_create_after = Column(UtcDateTime) |
| |
| __table_args__ = ( |
| Index("idx_root_dag_id", root_dag_id, unique=False), |
| Index("idx_next_dagrun_create_after", next_dagrun_create_after, unique=False), |
| ) |
| |
| parent_dag = relationship( |
| "DagModel", remote_side=[dag_id], primaryjoin=root_dag_id == dag_id, foreign_keys=[root_dag_id] |
| ) |
| schedule_dataset_references = relationship( |
| "DagScheduleDatasetReference", |
| cascade="all, delete, delete-orphan", |
| ) |
| schedule_datasets = association_proxy("schedule_dataset_references", "dataset") |
| task_outlet_dataset_references = relationship( |
| "TaskOutletDatasetReference", |
| cascade="all, delete, delete-orphan", |
| ) |
| NUM_DAGS_PER_DAGRUN_QUERY = conf.getint("scheduler", "max_dagruns_to_create_per_loop", fallback=10) |
| |
def __init__(self, concurrency=None, **kwargs):
    """
    Initialize the model and fill in limits that were not set explicitly.

    :param concurrency: deprecated alias for ``max_active_tasks``. It is used (with a
        ``RemovedInAirflow3Warning``) only when ``max_active_tasks`` was not set through ``kwargs``
        and ``concurrency`` is truthy; otherwise the configured default is used.
    :param kwargs: column values forwarded to the declarative base constructor
    """
    super().__init__(**kwargs)

    if self.max_active_tasks is None:
        if not concurrency:
            self.max_active_tasks = conf.getint("core", "max_active_tasks_per_dag")
        else:
            warnings.warn(
                "The 'DagModel.concurrency' parameter is deprecated. Please use 'max_active_tasks'.",
                RemovedInAirflow3Warning,
                stacklevel=2,
            )
            self.max_active_tasks = concurrency

    if self.max_active_runs is None:
        self.max_active_runs = conf.getint("core", "max_active_runs_per_dag")

    if self.has_task_concurrency_limits is None:
        # Be safe -- this will be updated later once the DAG is parsed
        self.has_task_concurrency_limits = True
| |
def __repr__(self):
    """Return a short debug representation such as ``<DAG: my_dag>``."""
    dag_id = self.dag_id
    return f"<DAG: {dag_id}>"
| |
@property
def next_dagrun_data_interval(self) -> DataInterval | None:
    """Data interval of the next DAG run, built by ``_get_model_data_interval`` from the two columns."""
    start_column = "next_dagrun_data_interval_start"
    end_column = "next_dagrun_data_interval_end"
    return _get_model_data_interval(self, start_column, end_column)

@next_dagrun_data_interval.setter
def next_dagrun_data_interval(self, value: tuple[datetime, datetime] | None) -> None:
    """Store ``value`` as (start, end) in the two columns; None clears both."""
    start, end = (None, None) if value is None else value
    self.next_dagrun_data_interval_start = start
    self.next_dagrun_data_interval_end = end
| |
@property
def timezone(self):
    """Return the Airflow-wide ``settings.TIMEZONE``; nothing per-DAG is consulted."""
    configured_tz = settings.TIMEZONE
    return configured_tz
| |
@staticmethod
@provide_session
def get_dagmodel(dag_id: str, session: Session = NEW_SESSION) -> DagModel | None:
    """
    Fetch the DagModel with primary key ``dag_id``, eagerly joining its parent DAG.

    :param dag_id: ID of the DAG to fetch
    :param session: ORM session
    :return: the DagModel, or None when no row has this ID
    """
    eager_parent = joinedload(DagModel.parent_dag)
    return session.get(DagModel, dag_id, options=[eager_parent])
| |
@classmethod
@provide_session
def get_current(cls, dag_id, session=NEW_SESSION):
    """Return the first ``cls`` row whose ``dag_id`` matches, or None if there is none."""
    matching = session.query(cls).filter(cls.dag_id == dag_id)
    return matching.first()
| |
@provide_session
def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False):
    """
    Return the last DAG run of this DAG.

    Delegates to the ``get_last_dagrun`` function in module scope, passing this model's ``dag_id``.

    :param session: ORM session
    :param include_externally_triggered: forwarded unchanged to the module-level function
    """
    dag_id = self.dag_id
    return get_last_dagrun(dag_id, session=session, include_externally_triggered=include_externally_triggered)
| |
def get_is_paused(self, *, session: Session | None = None) -> bool:
    """
    Return the ``is_paused`` value of this model.

    Provides interface compatibility with ``DAG``. ``session`` is accepted but not used.
    """
    paused = self.is_paused
    return paused
| |
@staticmethod
@internal_api_call
@provide_session
def get_paused_dag_ids(dag_ids: list[str], session: Session = NEW_SESSION) -> set[str]:
    """
    Return the subset of ``dag_ids`` whose DAGs are paused.

    :param dag_ids: List of Dag ids
    :param session: ORM Session
    :return: set of the paused DAG IDs
    """
    paused_query = (
        session.query(DagModel.dag_id)
        .filter(DagModel.is_paused == expression.true())
        .filter(DagModel.dag_id.in_(dag_ids))
    )
    rows = paused_query.all()
    return {dag_id for (dag_id,) in rows}
| |
def get_default_view(self) -> str:
    """
    Return the DAG's default view, falling back to the configured one.

    Old DAGs may have no ``default_view`` stored. For them (and for an empty value), the
    lower-cased ``[webserver] dag_default_view`` setting is returned.
    """
    if self.default_view:
        return self.default_view
    return conf.get_mandatory_value("webserver", "dag_default_view").lower()
| |
@property
def safe_dag_id(self):
    """``dag_id`` with every ``.`` replaced by ``__dot__``."""
    return "__dot__".join(self.dag_id.split("."))
| |
@property
def relative_fileloc(self) -> pathlib.Path | None:
    """
    File location of the importable dag 'file' relative to the configured DAGs folder.

    Returns None when ``fileloc`` is unset. Returns the path unchanged when it is not inside
    ``settings.DAGS_FOLDER``.
    """
    fileloc = self.fileloc
    if fileloc is None:
        return None
    dag_path = pathlib.Path(fileloc)
    try:
        return dag_path.relative_to(settings.DAGS_FOLDER)
    except ValueError:
        # Not relative to DAGS_FOLDER.
        return dag_path
| |
@provide_session
def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=NEW_SESSION) -> None:
    """
    Pause or un-pause this DAG in the database, then commit.

    :param is_paused: the new value of ``is_paused``
    :param including_subdags: also update rows whose ``root_dag_id`` is this DAG's ID (its subdags)
    :param session: ORM session; committed by this method
    """
    conditions = [DagModel.dag_id == self.dag_id]
    if including_subdags:
        # Subdag rows point at their parent through root_dag_id.
        conditions.append(DagModel.root_dag_id == self.dag_id)
    paused_update = {DagModel.is_paused: is_paused}
    matching_rows = session.query(DagModel).filter(or_(*conditions))
    matching_rows.update(paused_update, synchronize_session="fetch")
    session.commit()
| |
@classmethod
@internal_api_call
@provide_session
def deactivate_deleted_dags(cls, alive_dag_filelocs: list[str], session=NEW_SESSION):
    """
    Set ``is_active=False`` on the DAGs for which the DAG files have been removed.

    Rows whose ``fileloc`` is None are left untouched.

    :param alive_dag_filelocs: file paths of alive DAGs
    :param session: ORM Session
    """
    log.debug("Deactivating DAGs (for which DAG files are deleted) from %s table ", cls.__tablename__)

    # Build the lookup set once: each membership test is then O(1) rather than a list scan
    # per DAG model.
    alive_filelocs = set(alive_dag_filelocs)
    for dag_model in session.query(cls).all():
        if dag_model.fileloc is not None and dag_model.fileloc not in alive_filelocs:
            dag_model.is_active = False
| |
@classmethod
def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[datetime, datetime]]]:
    """
    Return (and lock) a list of Dag objects that are due to create a new DagRun.

    This will return a resultset of rows that is row-level-locked with a "SELECT ... FOR UPDATE" query,
    you should ensure that any scheduling decisions are made in a single transaction -- as soon as the
    transaction is committed it will be unlocked.

    A DAG qualifies when it is unpaused, active and free of import errors, and one of these holds:

    * its ``next_dagrun_create_after`` is not later than the database's ``now()``;
    * it is a dataset-triggered DAG that is ready to run.

    At most ``NUM_DAGS_PER_DAGRUN_QUERY`` DAGs are returned, ordered by ``next_dagrun_create_after``.

    :param session: ORM session; the returned rows stay locked until its transaction ends
    :return: a tuple of the row-locked DagModel query and a dict mapping each ready
        dataset-triggered DAG ID to ``(first_queued_time, last_queued_time)``. These are the min/max
        ``created_at`` of its DatasetDagRunQueue records.
    """
    from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ

    # these dag ids are triggered by datasets, and they are ready to go.
    # The HAVING clause keeps a DAG only when, across its outer-joined rows, every row has a
    # non-NULL DDRQ.target_dag_id. Presumably (via the queue_records relationship defined elsewhere)
    # this means every dataset the DAG is scheduled on has a queued record -- confirm.
    dataset_triggered_dag_info = {
        x.dag_id: (x.first_queued_time, x.last_queued_time)
        for x in session.query(
            DagScheduleDatasetReference.dag_id,
            func.max(DDRQ.created_at).label("last_queued_time"),
            func.min(DDRQ.created_at).label("first_queued_time"),
        )
        .join(DagScheduleDatasetReference.queue_records, isouter=True)
        .group_by(DagScheduleDatasetReference.dag_id)
        .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
        .all()
    }
    dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
    if dataset_triggered_dag_ids:
        # Drop dataset-triggered DAGs that already have max_active_runs (or more) runs
        # QUEUED or RUNNING.
        exclusion_list = {
            x.dag_id
            for x in (
                session.query(DagModel.dag_id)
                .join(DagRun.dag_model)
                .filter(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
                .filter(DagModel.dag_id.in_(dataset_triggered_dag_ids))
                .group_by(DagModel.dag_id)
                .having(func.count() >= func.max(DagModel.max_active_runs))
                .all()
            )
        }
        if exclusion_list:
            dataset_triggered_dag_ids -= exclusion_list
            dataset_triggered_dag_info = {
                k: v for k, v in dataset_triggered_dag_info.items() if k not in exclusion_list
            }

    # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs
    query = (
        session.query(cls)
        .filter(
            cls.is_paused == expression.false(),
            cls.is_active == expression.true(),
            cls.has_import_errors == expression.false(),
            or_(
                cls.next_dagrun_create_after <= func.now(),
                cls.dag_id.in_(dataset_triggered_dag_ids),
            ),
        )
        .order_by(cls.next_dagrun_create_after)
        .limit(cls.NUM_DAGS_PER_DAGRUN_QUERY)
    )

    # presumably skip_locked lets concurrent schedulers skip rows already locked by another
    # scheduler where the database supports it -- confirm against skip_locked's implementation.
    return (
        with_row_locks(query, of=cls, session=session, **skip_locked(session=session)),
        dataset_triggered_dag_info,
    )
| |
| def calculate_dagrun_date_fields( |
| self, |
| dag: DAG, |
| most_recent_dag_run: None | datetime | DataInterval, |
| ) -> None: |
| """ |
| Calculate ``next_dagrun`` and `next_dagrun_create_after``. |
| |
| :param dag: The DAG object |
| :param most_recent_dag_run: DataInterval (or datetime) of most recent run of this dag, or none |
| if not yet scheduled. |
| """ |
| most_recent_data_interval: DataInterval | None |
| if isinstance(most_recent_dag_run, datetime): |
| warnings.warn( |
| "Passing a datetime to `DagModel.calculate_dagrun_date_fields` is deprecated. " |
| "Provide a data interval instead.", |
| RemovedInAirflow3Warning, |
| stacklevel=2, |
| ) |
| most_recent_data_interval = dag.infer_automated_data_interval(most_recent_dag_run) |
| else: |
| most_recent_data_interval = most_recent_dag_run |
| next_dagrun_info = dag.next_dagrun_info(most_recent_data_interval) |
| if next_dagrun_info is None: |
| self.next_dagrun_data_interval = self.next_dagrun = self.next_dagrun_create_after = None |
| else: |
| self.next_dagrun_data_interval = next_dagrun_info.data_interval |
| self.next_dagrun = next_dagrun_info.logical_date |
| self.next_dagrun_create_after = next_dagrun_info.run_after |
| |
| log.info( |
| "Setting next_dagrun for %s to %s, run_after=%s", |
| dag.dag_id, |
| self.next_dagrun, |
| self.next_dagrun_create_after, |
| ) |
| |
| @provide_session |
| def get_dataset_triggered_next_run_info(self, *, session=NEW_SESSION) -> dict[str, int | str] | None: |
| if self.schedule_interval != "Dataset": |
| return None |
| return get_dataset_triggered_next_run_info([self.dag_id], session=session)[self.dag_id] |
| |
| |
| # NOTE: Please keep the list of arguments in sync with DAG.__init__. |
| # Only exception: dag_id here should have a default value, but not in DAG. |
def dag(
    dag_id: str = "",
    description: str | None = None,
    schedule: ScheduleArg = NOTSET,
    schedule_interval: ScheduleIntervalArg = NOTSET,
    timetable: Timetable | None = None,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    full_filepath: str | None = None,
    template_searchpath: str | Iterable[str] | None = None,
    template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined,
    user_defined_macros: dict | None = None,
    user_defined_filters: dict | None = None,
    default_args: dict | None = None,
    concurrency: int | None = None,
    max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"),
    max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"),
    dagrun_timeout: timedelta | None = None,
    sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None,
    default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(),
    orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"),
    catchup: bool = conf.getboolean("scheduler", "catchup_by_default"),
    on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
    on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
    doc_md: str | None = None,
    params: collections.abc.MutableMapping | None = None,
    access_control: dict | None = None,
    is_paused_upon_creation: bool | None = None,
    jinja_environment_kwargs: dict | None = None,
    render_template_as_native_obj: bool = False,
    tags: list[str] | None = None,
    owner_links: dict[str, str] | None = None,
    auto_register: bool = True,
    fail_stop: bool = False,
) -> Callable[[Callable], Callable[..., DAG]]:
    """
    Python dag decorator which wraps a function into an Airflow DAG.

    The decorator accepts the same arguments as the ``DAG`` constructor and forwards them to
    ``DAG`` every time the decorated function is called. If ``dag_id`` is empty, the decorated
    function's ``__name__`` is used instead.

    Calling the decorated function:

    * binds the call's arguments to the function's signature (a mismatch raises ``TypeError``)
      and applies the signature's defaults;
    * creates the ``DAG`` and enters it as a context manager;
    * uses the function's docstring as ``doc_md`` if the DAG has no ``doc_md`` set;
    * replaces every bound argument with the object returned by ``DAG.param(name, value)``;
    * sets the DAG's ``fileloc`` to the source file of the code that called the function;
    * calls the function, so operators it creates belong to the DAG, and returns the DAG.

    :return: A decorator turning a function into a factory that builds and returns a ``DAG``.
    """

    def wrapper(f: Callable) -> Callable[..., DAG]:
        @functools.wraps(f)
        def factory(*args, **kwargs):
            # Bind the call's arguments to the decorated function's signature so they can be
            # turned into DAG params; missing/extra args raise TypeError like a normal call.
            f_sig = signature(f).bind(*args, **kwargs)
            # Apply defaults so parameters the caller omitted still become DAG params.
            f_sig.apply_defaults()

            # Initialize DAG with bound arguments
            with DAG(
                dag_id or f.__name__,
                description=description,
                schedule_interval=schedule_interval,
                timetable=timetable,
                start_date=start_date,
                end_date=end_date,
                full_filepath=full_filepath,
                template_searchpath=template_searchpath,
                template_undefined=template_undefined,
                user_defined_macros=user_defined_macros,
                user_defined_filters=user_defined_filters,
                default_args=default_args,
                concurrency=concurrency,
                max_active_tasks=max_active_tasks,
                max_active_runs=max_active_runs,
                dagrun_timeout=dagrun_timeout,
                sla_miss_callback=sla_miss_callback,
                default_view=default_view,
                orientation=orientation,
                catchup=catchup,
                on_success_callback=on_success_callback,
                on_failure_callback=on_failure_callback,
                doc_md=doc_md,
                params=params,
                access_control=access_control,
                is_paused_upon_creation=is_paused_upon_creation,
                jinja_environment_kwargs=jinja_environment_kwargs,
                render_template_as_native_obj=render_template_as_native_obj,
                tags=tags,
                schedule=schedule,
                owner_links=owner_links,
                auto_register=auto_register,
                fail_stop=fail_stop,
            ) as dag_obj:
                # Set DAG documentation from function documentation if it exists and doc_md is not set.
                if f.__doc__ and not dag_obj.doc_md:
                    dag_obj.doc_md = f.__doc__

                # Replace every argument with a DAG param; the function receives these param
                # objects instead of the raw values, which are resolved at execution time.
                f_kwargs = {name: dag_obj.param(name, value) for name, value in f_sig.arguments.items()}

                # set file location to caller source path
                back = sys._getframe().f_back
                dag_obj.fileloc = back.f_code.co_filename if back else ""

                # Invoke function to create operators in the DAG scope.
                f(**f_kwargs)

            # Return dag object such that it's accessible in Globals.
            return dag_obj

        # Ensure that warnings from inside DAG() are emitted from the caller, not here
        fixup_decorator_warning_stack(factory)
        return factory

    return wrapper
| |
| |
# STATICA_HACK starts out True, but the next line rebinds it to False at runtime
# ("kcah_acitats"[::-1].upper() == "STATICA_HACK"), so the ``if`` body below never executes
# when this module is imported.
# NOTE(review): presumably the indirection exists so static analyzers/IDEs still treat the body
# as reachable and learn about ``DagModel.serialized_dag`` without importing
# airflow.models.serialized_dag at runtime (likely to avoid a circular import) — confirm where
# the relationship is actually configured at runtime.
STATICA_HACK = True
globals()["kcah_acitats"[::-1].upper()] = False
if STATICA_HACK:  # pragma: no cover

    from airflow.models.serialized_dag import SerializedDagModel

    DagModel.serialized_dag = relationship(SerializedDagModel)
    # The string below presumably marks this attribute to be skipped by sphinx-autoapi.
    """:sphinx-autoapi-skip:"""
| |
| |
class DagContext:
    """
    Keep track of the DAG that is currently in use as a context manager.

    Entering a DAG as a context manager makes it the current DAG:

    .. code-block:: python

        with DAG(
            dag_id="example_dag",
            default_args=default_args,
            schedule="0 0 * * *",
            dagrun_timeout=timedelta(minutes=60),
        ) as dag:
            ...

    While a DAG is current, newly created tasks use it as their parent DAG.

    DAGs are held in a stack whose most recently pushed entry is the current DAG. When
    ``current_autoregister_module_name`` is set, each truthy DAG popped off the stack whose
    ``auto_register`` is truthy is recorded in ``autoregistered_dags`` together with that module.
    """

    _context_managed_dags: Deque[DAG] = deque()
    autoregistered_dags: set[tuple[DAG, ModuleType]] = set()
    current_autoregister_module_name: str | None = None

    @classmethod
    def push_context_managed_dag(cls, dag: DAG):
        stack = cls._context_managed_dags
        stack.appendleft(dag)

    @classmethod
    def pop_context_managed_dag(cls) -> DAG | None:
        popped = cls._context_managed_dags.popleft()
        module_name = cls.current_autoregister_module_name

        # None is explicitly pushed onto the stack in a few serialization code paths, so falsy
        # entries are handed back without being registered.
        if module_name is None or not popped or not popped.auto_register:
            return popped

        cls.autoregistered_dags.add((popped, sys.modules[module_name]))
        return popped

    @classmethod
    def get_current_dag(cls) -> DAG | None:
        stack = cls._context_managed_dags
        return stack[0] if stack else None
| |
| |
def _run_task(ti: TaskInstance, session):
    """
    Run a single task instance, and push result to Xcom for downstream tasks.

    Bypasses a lot of extra steps used in `task.run` to keep our local running as fast as possible.
    This function is only meant for the `dag.test` function as a helper function.

    An ``AirflowSkipException`` raised by the task is logged and swallowed; any other exception
    propagates to the caller.

    :param ti: TaskInstance to run
    :param session: ORM Session used to run the task; it is flushed after a successful run.
    """
    log.info("*****************************************************")
    # Unmapped task instances carry map_index -1; mapped ones are numbered from 0, so index 0
    # (the first mapped instance) must be reported with its index too.
    if ti.map_index >= 0:
        log.info("Running task %s index %d", ti.task_id, ti.map_index)
    else:
        log.info("Running task %s", ti.task_id)
    try:
        ti._run_raw_task(session=session)
        session.flush()
        log.info("%s ran successfully!", ti.task_id)
    except AirflowSkipException:
        log.info("Task Skipped, continuing")
    log.info("*****************************************************")
| |
| |
def _get_or_create_dagrun(
    dag: DAG,
    conf: dict[Any, Any] | None,
    start_date: datetime | None,
    execution_date: datetime,
    run_id: str,
    session: Session,
) -> DagRun:
    """Create a DAG run, replacing an existing instance if needed to prevent collisions.

    This function is only meant to be used by :meth:`DAG.test` as a helper function.

    If a run of ``dag`` with the same ``execution_date`` already exists, it is deleted and the
    deletion is committed before a new ``RUNNING`` run is created.

    :param dag: DAG to be used to find run.
    :param conf: Configuration to pass to newly created run.
    :param start_date: Start date of new run; ``execution_date`` is used when this is falsy.
    :param execution_date: Logical date for finding an existing run.
    :param run_id: Run ID for the new DAG run.
    :param session: ORM Session used to look up, delete and create the run.

    :return: The newly created DAG run.
    """
    # The value logged here is the DAG's id, not a DAG run id.
    log.info("dag id: %s", dag.dag_id)
    existing: DagRun | None = (
        session.query(DagRun)
        .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
        .first()
    )
    if existing:
        session.delete(existing)
        # Commit the deletion before creating the replacement, presumably so the new run does
        # not collide with the old row for the same (dag_id, execution_date).
        session.commit()
    dr = dag.create_dagrun(
        state=DagRunState.RUNNING,
        execution_date=execution_date,
        run_id=run_id,
        start_date=start_date or execution_date,
        session=session,
        conf=conf,
    )
    log.info("created dagrun %s", dr)
    return dr