| # -*- coding: utf-8 -*- |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| from __future__ import unicode_literals |
| |
| from future.standard_library import install_aliases |
| |
| install_aliases() |
| from builtins import str |
| from builtins import object, bytes |
| import copy |
| from collections import namedtuple |
| from datetime import datetime, timedelta |
| import dill |
| import functools |
| import getpass |
| import imp |
| import importlib |
| import itertools |
| import zipfile |
| import jinja2 |
| import json |
| import logging |
| import os |
| import pickle |
| import re |
| import signal |
| import socket |
| import sys |
| import textwrap |
| import traceback |
| import warnings |
| import hashlib |
| from urllib.parse import urlparse |
| |
| from sqlalchemy import ( |
| Column, Integer, String, DateTime, Text, Boolean, ForeignKey, PickleType, |
| Index, Float, LargeBinary) |
| from sqlalchemy import func, or_, and_ |
| from sqlalchemy.ext.declarative import declarative_base, declared_attr |
| from sqlalchemy.dialects.mysql import LONGTEXT |
| from sqlalchemy.orm import reconstructor, relationship, synonym |
| |
| from croniter import croniter |
| import six |
| |
| from airflow import settings, utils |
| from airflow.executors import GetDefaultExecutor, LocalExecutor |
| from airflow import configuration |
| from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout |
| from airflow.dag.base_dag import BaseDag, BaseDagBag |
| from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep |
| from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep |
| from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep |
| from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep |
| |
| from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS |
| from airflow.utils.dates import cron_presets, date_range as utils_date_range |
| from airflow.utils.db import provide_session |
| from airflow.utils.decorators import apply_defaults |
| from airflow.utils.email import send_email |
| from airflow.utils.helpers import ( |
| as_tuple, is_container, is_in, validate_key, pprinttable) |
| from airflow.utils.operator_resources import Resources |
| from airflow.utils.state import State |
| from airflow.utils.timeout import timeout |
| from airflow.utils.trigger_rule import TriggerRule |
| from airflow.utils.log.logging_mixin import LoggingMixin |
| |
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()

# Width of the String columns that hold identifiers (dag_id, task_id, ...).
ID_LEN = 250
XCOM_RETURN_KEY = 'return_value'

# Module-level shortcut to the stats client exposed by ``settings``.
Stats = settings.Stats
| |
def get_fernet():
    """
    Deferred load of Fernet key.

    This function could fail either because Cryptography is not installed
    or because the Fernet key is invalid.

    :return: Fernet object
    :raises: AirflowException if there's a problem trying to load Fernet
    """
    try:
        from cryptography.fernet import Fernet
    except ImportError:
        # Only a missing/broken installation is reported as "not installed";
        # unrelated errors (e.g. KeyboardInterrupt) are no longer masked.
        raise AirflowException('Failed to import Fernet, it may not be installed')
    try:
        return Fernet(configuration.get('core', 'FERNET_KEY').encode('utf-8'))
    except (ValueError, TypeError) as ve:
        # A malformed key can surface as either exception type depending on
        # how the base64 decoding fails; both mean "invalid key".
        raise AirflowException("Could not create Fernet object: {}".format(ve))
| |
| |
# Pick the MySQL-specific LONGTEXT type when the configured connection string
# mentions mysql; every other backend gets the generic Text type.
LongText = LONGTEXT if 'mysql' in settings.SQL_ALCHEMY_CONN else Text

# Used by DAG context_managers
_CONTEXT_MANAGER_DAG = None
| |
| |
def clear_task_instances(tis, session, activate_dag_runs=True, dag=None):
    """
    Clears a set of task instances, but makes sure the running ones
    get killed.

    :param tis: task instances to clear; the collection is iterated more
        than once, so it must be re-iterable (e.g. a list)
    :param session: SQLAlchemy session used to persist the changes
    :param activate_dag_runs: when true, the DAG runs owning the cleared
        task instances are set back to RUNNING
    :type activate_dag_runs: bool
    :param dag: optional DAG used to look up each task's ``retries`` when
        recomputing ``max_tries``
    """
    job_ids = []
    for ti in tis:
        if ti.state == State.RUNNING:
            if ti.job_id:
                # Running instances are flagged for shutdown (together with
                # their job, below) rather than being reset directly.
                ti.state = State.SHUTDOWN
                job_ids.append(ti.job_id)
        else:
            task_id = ti.task_id
            if dag and dag.has_task(task_id):
                task = dag.get_task(task_id)
                task_retries = task.retries
                ti.max_tries = ti.try_number + task_retries - 1
            else:
                # Ignore errors when updating max_tries if dag is None or
                # task not found in dag since database records could be
                # outdated. We make max_tries the maximum value of its
                # original max_tries or the current task try number.
                ti.max_tries = max(ti.max_tries, ti.try_number - 1)
            ti.state = State.NONE
            session.merge(ti)

    if job_ids:
        from airflow.jobs import BaseJob as BJ
        for job in session.query(BJ).filter(BJ.id.in_(job_ids)).all():
            job.state = State.SHUTDOWN

    if activate_dag_runs and tis:
        # Group execution dates per DAG. Filtering on "any dag_id AND any
        # execution_date" would also match runs of DAG A that merely share
        # an execution date with a cleared task instance of DAG B.
        dates_by_dag_id = {}
        for ti in tis:
            dates_by_dag_id.setdefault(ti.dag_id, set()).add(ti.execution_date)

        for dag_id, execution_dates in dates_by_dag_id.items():
            drs = session.query(DagRun).filter(
                DagRun.dag_id == dag_id,
                DagRun.execution_date.in_(execution_dates),
            ).all()
            for dr in drs:
                dr.state = State.RUNNING
                dr.start_date = datetime.utcnow()
| |
| |
| class DagBag(BaseDagBag, LoggingMixin): |
| """ |
| A dagbag is a collection of dags, parsed out of a folder tree and has high |
| level configuration settings, like what database to use as a backend and |
| what executor to use to fire off tasks. This makes it easier to run |
| distinct environments for say production and development, tests, or for |
| different teams or security profiles. What would have been system level |
| settings are now dagbag level so that one system can run multiple, |
| independent settings sets. |
| |
| :param dag_folder: the folder to scan to find DAGs |
| :type dag_folder: unicode |
| :param executor: the executor to use when executing task instances |
| in this DagBag |
| :param include_examples: whether to include the examples that ship |
| with airflow or not |
| :type include_examples: bool |
| """ |
| |
| def __init__( |
| self, |
| dag_folder=None, |
| executor=None, |
| include_examples=configuration.getboolean('core', 'LOAD_EXAMPLES')): |
| |
| # do not use default arg in signature, to fix import cycle on plugin load |
| if executor is None: |
| executor = GetDefaultExecutor() |
| dag_folder = dag_folder or settings.DAGS_FOLDER |
| self.log.info("Filling up the DagBag from %s", dag_folder) |
| self.dag_folder = dag_folder |
| self.dags = {} |
| # the file's last modified timestamp when we last read it |
| self.file_last_changed = {} |
| self.executor = executor |
| self.import_errors = {} |
| |
| if include_examples: |
| example_dag_folder = os.path.join( |
| os.path.dirname(__file__), |
| 'example_dags') |
| self.collect_dags(example_dag_folder) |
| self.collect_dags(dag_folder) |
| |
| def size(self): |
| """ |
| :return: the amount of dags contained in this dagbag |
| """ |
| return len(self.dags) |
| |
| def get_dag(self, dag_id): |
| """ |
| Gets the DAG out of the dictionary, and refreshes it if expired |
| """ |
| # If asking for a known subdag, we want to refresh the parent |
| root_dag_id = dag_id |
| if dag_id in self.dags: |
| dag = self.dags[dag_id] |
| if dag.is_subdag: |
| root_dag_id = dag.parent_dag.dag_id |
| |
| # If the dag corresponding to root_dag_id is absent or expired |
| orm_dag = DagModel.get_current(root_dag_id) |
| if orm_dag and ( |
| root_dag_id not in self.dags or |
| ( |
| orm_dag.last_expired and |
| dag.last_loaded < orm_dag.last_expired |
| ) |
| ): |
| # Reprocess source file |
| found_dags = self.process_file( |
| filepath=orm_dag.fileloc, only_if_updated=False) |
| |
| # If the source file no longer exports `dag_id`, delete it from self.dags |
| if found_dags and dag_id in [dag.dag_id for dag in found_dags]: |
| return self.dags[dag_id] |
| elif dag_id in self.dags: |
| del self.dags[dag_id] |
| return self.dags.get(dag_id) |
| |
| def process_file(self, filepath, only_if_updated=True, safe_mode=True): |
| """ |
| Given a path to a python module or zip file, this method imports |
| the module and look for dag objects within it. |
| """ |
| found_dags = [] |
| |
| # todo: raise exception? |
| if not os.path.isfile(filepath): |
| return found_dags |
| |
| try: |
| # This failed before in what may have been a git sync |
| # race condition |
| file_last_changed_on_disk = datetime.fromtimestamp(os.path.getmtime(filepath)) |
| if only_if_updated \ |
| and filepath in self.file_last_changed \ |
| and file_last_changed_on_disk == self.file_last_changed[filepath]: |
| return found_dags |
| |
| except Exception as e: |
| self.log.exception(e) |
| return found_dags |
| |
| mods = [] |
| if not zipfile.is_zipfile(filepath): |
| if safe_mode and os.path.isfile(filepath): |
| with open(filepath, 'rb') as f: |
| content = f.read() |
| if not all([s in content for s in (b'DAG', b'airflow')]): |
| self.file_last_changed[filepath] = file_last_changed_on_disk |
| return found_dags |
| |
| self.log.debug("Importing %s", filepath) |
| org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1]) |
| mod_name = ('unusual_prefix_' + |
| hashlib.sha1(filepath.encode('utf-8')).hexdigest() + |
| '_' + org_mod_name) |
| |
| if mod_name in sys.modules: |
| del sys.modules[mod_name] |
| |
| with timeout(configuration.getint('core', "DAGBAG_IMPORT_TIMEOUT")): |
| try: |
| m = imp.load_source(mod_name, filepath) |
| mods.append(m) |
| except Exception as e: |
| self.log.exception("Failed to import: %s", filepath) |
| self.import_errors[filepath] = str(e) |
| self.file_last_changed[filepath] = file_last_changed_on_disk |
| |
| else: |
| zip_file = zipfile.ZipFile(filepath) |
| for mod in zip_file.infolist(): |
| head, _ = os.path.split(mod.filename) |
| mod_name, ext = os.path.splitext(mod.filename) |
| if not head and (ext == '.py' or ext == '.pyc'): |
| if mod_name == '__init__': |
| self.log.warning("Found __init__.%s at root of %s", ext, filepath) |
| if safe_mode: |
| with zip_file.open(mod.filename) as zf: |
| self.log.debug("Reading %s from %s", mod.filename, filepath) |
| content = zf.read() |
| if not all([s in content for s in (b'DAG', b'airflow')]): |
| self.file_last_changed[filepath] = ( |
| file_last_changed_on_disk) |
| # todo: create ignore list |
| return found_dags |
| |
| if mod_name in sys.modules: |
| del sys.modules[mod_name] |
| |
| try: |
| sys.path.insert(0, filepath) |
| m = importlib.import_module(mod_name) |
| mods.append(m) |
| except Exception as e: |
| self.log.exception("Failed to import: %s", filepath) |
| self.import_errors[filepath] = str(e) |
| self.file_last_changed[filepath] = file_last_changed_on_disk |
| |
| for m in mods: |
| for dag in list(m.__dict__.values()): |
| if isinstance(dag, DAG): |
| if not dag.full_filepath: |
| dag.full_filepath = filepath |
| dag.is_subdag = False |
| self.bag_dag(dag, parent_dag=dag, root_dag=dag) |
| found_dags.append(dag) |
| found_dags += dag.subdags |
| |
| self.file_last_changed[filepath] = file_last_changed_on_disk |
| return found_dags |
| |
| @provide_session |
| def kill_zombies(self, session=None): |
| """ |
| Fails tasks that haven't had a heartbeat in too long |
| """ |
| from airflow.jobs import LocalTaskJob as LJ |
| self.log.info("Finding 'running' jobs without a recent heartbeat") |
| TI = TaskInstance |
| secs = configuration.getint('scheduler', 'scheduler_zombie_task_threshold') |
| limit_dttm = datetime.utcnow() - timedelta(seconds=secs) |
| self.log.info("Failing jobs without heartbeat after %s", limit_dttm) |
| |
| tis = ( |
| session.query(TI) |
| .join(LJ, TI.job_id == LJ.id) |
| .filter(TI.state == State.RUNNING) |
| .filter( |
| or_( |
| LJ.state != State.RUNNING, |
| LJ.latest_heartbeat < limit_dttm, |
| )) |
| .all() |
| ) |
| |
| for ti in tis: |
| if ti and ti.dag_id in self.dags: |
| dag = self.dags[ti.dag_id] |
| if ti.task_id in dag.task_ids: |
| task = dag.get_task(ti.task_id) |
| ti.task = task |
| ti.handle_failure("{} killed as zombie".format(str(ti))) |
| self.log.info('Marked zombie job %s as failed', ti) |
| Stats.incr('zombies_killed') |
| session.commit() |
| |
| def bag_dag(self, dag, parent_dag, root_dag): |
| """ |
| Adds the DAG into the bag, recurses into sub dags. |
| """ |
| self.dags[dag.dag_id] = dag |
| dag.resolve_template_files() |
| dag.last_loaded = datetime.utcnow() |
| |
| for task in dag.tasks: |
| settings.policy(task) |
| |
| for subdag in dag.subdags: |
| subdag.full_filepath = dag.full_filepath |
| subdag.parent_dag = dag |
| subdag.is_subdag = True |
| self.bag_dag(subdag, parent_dag=dag, root_dag=root_dag) |
| self.log.debug('Loaded DAG {dag}'.format(**locals())) |
| |
| def collect_dags( |
| self, |
| dag_folder=None, |
| only_if_updated=True): |
| """ |
| Given a file path or a folder, this method looks for python modules, |
| imports them and adds them to the dagbag collection. |
| |
| Note that if a .airflowignore file is found while processing, |
| the directory, it will behaves much like a .gitignore does, |
| ignoring files that match any of the regex patterns specified |
| in the file. |
| """ |
| start_dttm = datetime.utcnow() |
| dag_folder = dag_folder or self.dag_folder |
| |
| # Used to store stats around DagBag processing |
| stats = [] |
| FileLoadStat = namedtuple( |
| 'FileLoadStat', "file duration dag_num task_num dags") |
| if os.path.isfile(dag_folder): |
| self.process_file(dag_folder, only_if_updated=only_if_updated) |
| elif os.path.isdir(dag_folder): |
| patterns = [] |
| for root, dirs, files in os.walk(dag_folder, followlinks=True): |
| ignore_file = [f for f in files if f == '.airflowignore'] |
| if ignore_file: |
| f = open(os.path.join(root, ignore_file[0]), 'r') |
| patterns += [p for p in f.read().split('\n') if p] |
| f.close() |
| for f in files: |
| try: |
| filepath = os.path.join(root, f) |
| if not os.path.isfile(filepath): |
| continue |
| mod_name, file_ext = os.path.splitext( |
| os.path.split(filepath)[-1]) |
| if file_ext != '.py' and not zipfile.is_zipfile(filepath): |
| continue |
| if not any( |
| [re.findall(p, filepath) for p in patterns]): |
| ts = datetime.utcnow() |
| found_dags = self.process_file( |
| filepath, only_if_updated=only_if_updated) |
| |
| td = datetime.utcnow() - ts |
| td = td.total_seconds() + ( |
| float(td.microseconds) / 1000000) |
| stats.append(FileLoadStat( |
| filepath.replace(dag_folder, ''), |
| td, |
| len(found_dags), |
| sum([len(dag.tasks) for dag in found_dags]), |
| str([dag.dag_id for dag in found_dags]), |
| )) |
| except Exception as e: |
| self.log.warning(e) |
| Stats.gauge( |
| 'collect_dags', (datetime.utcnow() - start_dttm).total_seconds(), 1) |
| Stats.gauge( |
| 'dagbag_size', len(self.dags), 1) |
| Stats.gauge( |
| 'dagbag_import_errors', len(self.import_errors), 1) |
| self.dagbag_stats = sorted( |
| stats, key=lambda x: x.duration, reverse=True) |
| |
| def dagbag_report(self): |
| """Prints a report around DagBag loading stats""" |
| report = textwrap.dedent("""\n |
| ------------------------------------------------------------------- |
| DagBag loading stats for {dag_folder} |
| ------------------------------------------------------------------- |
| Number of DAGs: {dag_num} |
| Total task number: {task_num} |
| DagBag parsing time: {duration} |
| {table} |
| """) |
| stats = self.dagbag_stats |
| return report.format( |
| dag_folder=self.dag_folder, |
| duration=sum([o.duration for o in stats]), |
| dag_num=sum([o.dag_num for o in stats]), |
| task_num=sum([o.dag_num for o in stats]), |
| table=pprinttable(stats), |
| ) |
| |
| def deactivate_inactive_dags(self): |
| active_dag_ids = [dag.dag_id for dag in list(self.dags.values())] |
| session = settings.Session() |
| for dag in session.query( |
| DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all(): |
| dag.is_active = False |
| session.merge(dag) |
| session.commit() |
| session.close() |
| |
| def paused_dags(self): |
| session = settings.Session() |
| dag_ids = [dp.dag_id for dp in session.query(DagModel).filter( |
| DagModel.is_paused.__eq__(True))] |
| session.commit() |
| session.close() |
| return dag_ids |
| |
| |
class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    username = Column(String(ID_LEN), unique=True)
    email = Column(String(500))
    # Plain class attribute (not a column): False unless overridden.
    superuser = False

    def __repr__(self):
        """Return the username."""
        return self.username

    def get_id(self):
        """Return the primary key rendered as a string."""
        user_id = self.id
        return str(user_id)

    def is_superuser(self):
        """Return the value of the ``superuser`` attribute."""
        return self.superuser
| |
| |
| class Connection(Base, LoggingMixin): |
| """ |
| Placeholder to store information about different database instances |
| connection information. The idea here is that scripts use references to |
| database instances (conn_id) instead of hard coding hostname, logins and |
| passwords when using operators or hooks. |
| """ |
| __tablename__ = "connection" |
| |
| id = Column(Integer(), primary_key=True) |
| conn_id = Column(String(ID_LEN)) |
| conn_type = Column(String(500)) |
| host = Column(String(500)) |
| schema = Column(String(500)) |
| login = Column(String(500)) |
| _password = Column('password', String(5000)) |
| port = Column(Integer()) |
| is_encrypted = Column(Boolean, unique=False, default=False) |
| is_extra_encrypted = Column(Boolean, unique=False, default=False) |
| _extra = Column('extra', String(5000)) |
| |
| _types = [ |
| ('docker', 'Docker Registry',), |
| ('fs', 'File (path)'), |
| ('ftp', 'FTP',), |
| ('google_cloud_platform', 'Google Cloud Platform'), |
| ('hdfs', 'HDFS',), |
| ('http', 'HTTP',), |
| ('hive_cli', 'Hive Client Wrapper',), |
| ('hive_metastore', 'Hive Metastore Thrift',), |
| ('hiveserver2', 'Hive Server 2 Thrift',), |
| ('jdbc', 'Jdbc Connection',), |
| ('mysql', 'MySQL',), |
| ('postgres', 'Postgres',), |
| ('oracle', 'Oracle',), |
| ('vertica', 'Vertica',), |
| ('presto', 'Presto',), |
| ('s3', 'S3',), |
| ('samba', 'Samba',), |
| ('sqlite', 'Sqlite',), |
| ('ssh', 'SSH',), |
| ('cloudant', 'IBM Cloudant',), |
| ('mssql', 'Microsoft SQL Server'), |
| ('mesos_framework-id', 'Mesos Framework ID'), |
| ('jira', 'JIRA',), |
| ('redis', 'Redis',), |
| ('wasb', 'Azure Blob Storage'), |
| ('databricks', 'Databricks',), |
| ('aws', 'Amazon Web Services',), |
| ('emr', 'Elastic MapReduce',), |
| ] |
| |
| def __init__( |
| self, conn_id=None, conn_type=None, |
| host=None, login=None, password=None, |
| schema=None, port=None, extra=None, |
| uri=None): |
| self.conn_id = conn_id |
| if uri: |
| self.parse_from_uri(uri) |
| else: |
| self.conn_type = conn_type |
| self.host = host |
| self.login = login |
| self.password = password |
| self.schema = schema |
| self.port = port |
| self.extra = extra |
| |
| def parse_from_uri(self, uri): |
| temp_uri = urlparse(uri) |
| hostname = temp_uri.hostname or '' |
| if '%2f' in hostname: |
| hostname = hostname.replace('%2f', '/').replace('%2F', '/') |
| conn_type = temp_uri.scheme |
| if conn_type == 'postgresql': |
| conn_type = 'postgres' |
| self.conn_type = conn_type |
| self.host = hostname |
| self.schema = temp_uri.path[1:] |
| self.login = temp_uri.username |
| self.password = temp_uri.password |
| self.port = temp_uri.port |
| |
| def get_password(self): |
| if self._password and self.is_encrypted: |
| try: |
| fernet = get_fernet() |
| except: |
| raise AirflowException( |
| "Can't decrypt encrypted password for login={}, \ |
| FERNET_KEY configuration is missing".format(self.login)) |
| return fernet.decrypt(bytes(self._password, 'utf-8')).decode() |
| else: |
| return self._password |
| |
| def set_password(self, value): |
| if value: |
| try: |
| fernet = get_fernet() |
| self._password = fernet.encrypt(bytes(value, 'utf-8')).decode() |
| self.is_encrypted = True |
| except AirflowException: |
| self.log.exception("Failed to load fernet while encrypting value, " |
| "using non-encrypted value.") |
| self._password = value |
| self.is_encrypted = False |
| |
| @declared_attr |
| def password(cls): |
| return synonym('_password', |
| descriptor=property(cls.get_password, cls.set_password)) |
| |
| def get_extra(self): |
| if self._extra and self.is_extra_encrypted: |
| try: |
| fernet = get_fernet() |
| except: |
| raise AirflowException( |
| "Can't decrypt `extra` params for login={},\ |
| FERNET_KEY configuration is missing".format(self.login)) |
| return fernet.decrypt(bytes(self._extra, 'utf-8')).decode() |
| else: |
| return self._extra |
| |
| def set_extra(self, value): |
| if value: |
| try: |
| fernet = get_fernet() |
| self._extra = fernet.encrypt(bytes(value, 'utf-8')).decode() |
| self.is_extra_encrypted = True |
| except AirflowException: |
| self.log.exception("Failed to load fernet while encrypting value, " |
| "using non-encrypted value.") |
| self._extra = value |
| self.is_extra_encrypted = False |
| |
| @declared_attr |
| def extra(cls): |
| return synonym('_extra', |
| descriptor=property(cls.get_extra, cls.set_extra)) |
| |
| def get_hook(self): |
| try: |
| if self.conn_type == 'mysql': |
| from airflow.hooks.mysql_hook import MySqlHook |
| return MySqlHook(mysql_conn_id=self.conn_id) |
| elif self.conn_type == 'google_cloud_platform': |
| from airflow.contrib.hooks.bigquery_hook import BigQueryHook |
| return BigQueryHook(bigquery_conn_id=self.conn_id) |
| elif self.conn_type == 'postgres': |
| from airflow.hooks.postgres_hook import PostgresHook |
| return PostgresHook(postgres_conn_id=self.conn_id) |
| elif self.conn_type == 'hive_cli': |
| from airflow.hooks.hive_hooks import HiveCliHook |
| return HiveCliHook(hive_cli_conn_id=self.conn_id) |
| elif self.conn_type == 'presto': |
| from airflow.hooks.presto_hook import PrestoHook |
| return PrestoHook(presto_conn_id=self.conn_id) |
| elif self.conn_type == 'hiveserver2': |
| from airflow.hooks.hive_hooks import HiveServer2Hook |
| return HiveServer2Hook(hiveserver2_conn_id=self.conn_id) |
| elif self.conn_type == 'sqlite': |
| from airflow.hooks.sqlite_hook import SqliteHook |
| return SqliteHook(sqlite_conn_id=self.conn_id) |
| elif self.conn_type == 'jdbc': |
| from airflow.hooks.jdbc_hook import JdbcHook |
| return JdbcHook(jdbc_conn_id=self.conn_id) |
| elif self.conn_type == 'mssql': |
| from airflow.hooks.mssql_hook import MsSqlHook |
| return MsSqlHook(mssql_conn_id=self.conn_id) |
| elif self.conn_type == 'oracle': |
| from airflow.hooks.oracle_hook import OracleHook |
| return OracleHook(oracle_conn_id=self.conn_id) |
| elif self.conn_type == 'vertica': |
| from airflow.contrib.hooks.vertica_hook import VerticaHook |
| return VerticaHook(vertica_conn_id=self.conn_id) |
| elif self.conn_type == 'cloudant': |
| from airflow.contrib.hooks.cloudant_hook import CloudantHook |
| return CloudantHook(cloudant_conn_id=self.conn_id) |
| elif self.conn_type == 'jira': |
| from airflow.contrib.hooks.jira_hook import JiraHook |
| return JiraHook(jira_conn_id=self.conn_id) |
| elif self.conn_type == 'redis': |
| from airflow.contrib.hooks.redis_hook import RedisHook |
| return RedisHook(redis_conn_id=self.conn_id) |
| elif self.conn_type == 'wasb': |
| from airflow.contrib.hooks.wasb_hook import WasbHook |
| return WasbHook(wasb_conn_id=self.conn_id) |
| elif self.conn_type == 'docker': |
| from airflow.hooks.docker_hook import DockerHook |
| return DockerHook(docker_conn_id=self.conn_id) |
| except: |
| pass |
| |
| def __repr__(self): |
| return self.conn_id |
| |
| @property |
| def extra_dejson(self): |
| """Returns the extra property by deserializing json.""" |
| obj = {} |
| if self.extra: |
| try: |
| obj = json.loads(self.extra) |
| except Exception as e: |
| self.log.exception(e) |
| self.log.error("Failed parsing the json for conn_id %s", self.conn_id) |
| |
| return obj |
| |
| |
class DagPickle(Base):
    """
    Dags can originate from different places (user repos, master repo, ...)
    and also get executed in different places (different executors). This
    object represents a version of a DAG and becomes a source of truth for
    a BackfillJob execution. A pickle is a native python serialized object,
    and in this case gets stored in the database for the duration of the job.

    The executors pick up the DagPickle id and read the dag definition from
    the database.
    """
    __tablename__ = "dag_pickle"

    id = Column(Integer, primary_key=True)
    # The DAG object itself, serialized with dill.
    pickle = Column(PickleType(pickler=dill))
    created_dttm = Column(DateTime, default=func.now())
    pickle_hash = Column(Text)

    def __init__(self, dag):
        """
        Wrap ``dag`` for storage.

        Note that ``dag.template_env`` is reset to None on the object that
        is passed in, before it is hashed and stored.
        """
        self.dag_id = dag.dag_id
        has_template_env = hasattr(dag, 'template_env')
        if has_template_env:
            dag.template_env = None
        self.pickle_hash = hash(dag)
        self.pickle = dag
| |
| |
| class TaskInstance(Base, LoggingMixin): |
| """ |
| Task instances store the state of a task instance. This table is the |
| authority and single source of truth around what tasks have run and the |
| state they are in. |
| |
| The SqlAlchemy model doesn't have a SqlAlchemy foreign key to the task or |
| dag model deliberately to have more control over transactions. |
| |
| Database transactions on this table should insure double triggers and |
| any confusion around what task instances are or aren't ready to run |
| even while multiple schedulers may be firing task instances. |
| """ |
| |
| __tablename__ = "task_instance" |
| |
| task_id = Column(String(ID_LEN), primary_key=True) |
| dag_id = Column(String(ID_LEN), primary_key=True) |
| execution_date = Column(DateTime, primary_key=True) |
| start_date = Column(DateTime) |
| end_date = Column(DateTime) |
| duration = Column(Float) |
| state = Column(String(20)) |
| _try_number = Column('try_number', Integer, default=0) |
| max_tries = Column(Integer) |
| hostname = Column(String(1000)) |
| unixname = Column(String(1000)) |
| job_id = Column(Integer) |
| pool = Column(String(50)) |
| queue = Column(String(50)) |
| priority_weight = Column(Integer) |
| operator = Column(String(1000)) |
| queued_dttm = Column(DateTime) |
| pid = Column(Integer) |
| |
| __table_args__ = ( |
| Index('ti_dag_state', dag_id, state), |
| Index('ti_state', state), |
| Index('ti_state_lkp', dag_id, task_id, execution_date, state), |
| Index('ti_pool', pool, state, priority_weight), |
| ) |
| |
| def __init__(self, task, execution_date, state=None): |
| self.dag_id = task.dag_id |
| self.task_id = task.task_id |
| self.execution_date = execution_date |
| self.task = task |
| self.queue = task.queue |
| self.pool = task.pool |
| self.priority_weight = task.priority_weight_total |
| self.try_number = 0 |
| self.max_tries = self.task.retries |
| self.unixname = getpass.getuser() |
| self.run_as_user = task.run_as_user |
| if state: |
| self.state = state |
| self.hostname = '' |
| self.init_on_load() |
| self._log = logging.getLogger("airflow.task") |
| |
| @reconstructor |
| def init_on_load(self): |
| """ Initialize the attributes that aren't stored in the DB. """ |
| self.test_mode = False # can be changed when calling 'run' |
| |
| @property |
| def try_number(self): |
| """ |
| Return the try number that this task number will be when it is acutally |
| run. |
| |
| If the TI is currently running, this will match the column in the |
| databse, in all othercases this will be incremenetd |
| """ |
| # This is designed so that task logs end up in the right file. |
| if self.state == State.RUNNING: |
| return self._try_number |
| return self._try_number + 1 |
| |
| @try_number.setter |
| def try_number(self, value): |
| self._try_number = value |
| |
| @property |
| def next_try_number(self): |
| return self._try_number + 1 |
| |
| def command( |
| self, |
| mark_success=False, |
| ignore_all_deps=False, |
| ignore_depends_on_past=False, |
| ignore_task_deps=False, |
| ignore_ti_state=False, |
| local=False, |
| pickle_id=None, |
| raw=False, |
| job_id=None, |
| pool=None, |
| cfg_path=None): |
| """ |
| Returns a command that can be executed anywhere where airflow is |
| installed. This command is part of the message sent to executors by |
| the orchestrator. |
| """ |
| return " ".join(self.command_as_list( |
| mark_success=mark_success, |
| ignore_all_deps=ignore_all_deps, |
| ignore_depends_on_past=ignore_depends_on_past, |
| ignore_task_deps=ignore_task_deps, |
| ignore_ti_state=ignore_ti_state, |
| local=local, |
| pickle_id=pickle_id, |
| raw=raw, |
| job_id=job_id, |
| pool=pool, |
| cfg_path=cfg_path)) |
| |
| def command_as_list( |
| self, |
| mark_success=False, |
| ignore_all_deps=False, |
| ignore_task_deps=False, |
| ignore_depends_on_past=False, |
| ignore_ti_state=False, |
| local=False, |
| pickle_id=None, |
| raw=False, |
| job_id=None, |
| pool=None, |
| cfg_path=None): |
| """ |
| Returns a command that can be executed anywhere where airflow is |
| installed. This command is part of the message sent to executors by |
| the orchestrator. |
| """ |
| dag = self.task.dag |
| |
| should_pass_filepath = not pickle_id and dag |
| if should_pass_filepath and dag.full_filepath != dag.filepath: |
| path = "DAGS_FOLDER/{}".format(dag.filepath) |
| elif should_pass_filepath and dag.full_filepath: |
| path = dag.full_filepath |
| else: |
| path = None |
| |
| return TaskInstance.generate_command( |
| self.dag_id, |
| self.task_id, |
| self.execution_date, |
| mark_success=mark_success, |
| ignore_all_deps=ignore_all_deps, |
| ignore_task_deps=ignore_task_deps, |
| ignore_depends_on_past=ignore_depends_on_past, |
| ignore_ti_state=ignore_ti_state, |
| local=local, |
| pickle_id=pickle_id, |
| file_path=path, |
| raw=raw, |
| job_id=job_id, |
| pool=pool, |
| cfg_path=cfg_path) |
| |
| @staticmethod |
| def generate_command(dag_id, |
| task_id, |
| execution_date, |
| mark_success=False, |
| ignore_all_deps=False, |
| ignore_depends_on_past=False, |
| ignore_task_deps=False, |
| ignore_ti_state=False, |
| local=False, |
| pickle_id=None, |
| file_path=None, |
| raw=False, |
| job_id=None, |
| pool=None, |
| cfg_path=None |
| ): |
| """ |
| Generates the shell command required to execute this task instance. |
| |
| :param dag_id: DAG ID |
| :type dag_id: unicode |
| :param task_id: Task ID |
| :type task_id: unicode |
| :param execution_date: Execution date for the task |
| :type execution_date: datetime |
| :param mark_success: Whether to mark the task as successful |
| :type mark_success: bool |
| :param ignore_all_deps: Ignore all ignorable dependencies. |
| Overrides the other ignore_* parameters. |
| :type ignore_all_deps: boolean |
| :param ignore_depends_on_past: Ignore depends_on_past parameter of DAGs |
| (e.g. for Backfills) |
| :type ignore_depends_on_past: boolean |
| :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past |
| and trigger rule |
| :type ignore_task_deps: boolean |
| :param ignore_ti_state: Ignore the task instance's previous failure/success |
| :type ignore_ti_state: boolean |
| :param local: Whether to run the task locally |
| :type local: bool |
| :param pickle_id: If the DAG was serialized to the DB, the ID |
| associated with the pickled DAG |
| :type pickle_id: unicode |
| :param file_path: path to the file containing the DAG definition |
| :param raw: raw mode (needs more details) |
| :param job_id: job ID (needs more details) |
| :param pool: the Airflow pool that the task should run in |
| :type pool: unicode |
| :return: shell command that can be used to run the task instance |
| """ |
| iso = execution_date.isoformat() |
| cmd = ["airflow", "run", str(dag_id), str(task_id), str(iso)] |
| cmd.extend(["--mark_success"]) if mark_success else None |
| cmd.extend(["--pickle", str(pickle_id)]) if pickle_id else None |
| cmd.extend(["--job_id", str(job_id)]) if job_id else None |
| cmd.extend(["-A"]) if ignore_all_deps else None |
| cmd.extend(["-i"]) if ignore_task_deps else None |
| cmd.extend(["-I"]) if ignore_depends_on_past else None |
| cmd.extend(["--force"]) if ignore_ti_state else None |
| cmd.extend(["--local"]) if local else None |
| cmd.extend(["--pool", pool]) if pool else None |
| cmd.extend(["--raw"]) if raw else None |
| cmd.extend(["-sd", file_path]) if file_path else None |
| cmd.extend(["--cfg_path", cfg_path]) if cfg_path else None |
| return cmd |
| |
@property
def log_filepath(self):
    """
    Path of the log file for this task instance.

    The path has the form
    ``<BASE_LOG_FOLDER>/<dag_id>/<task_id>/<execution_date ISO 8601>.log``
    where ``BASE_LOG_FOLDER`` comes from the ``core`` configuration
    section, with a leading ``~`` expanded to the user's home directory.
    """
    timestamp = self.execution_date.isoformat()
    base_folder = os.path.expanduser(
        configuration.get('core', 'BASE_LOG_FOLDER'))
    template = "{log}/{self.dag_id}/{self.task_id}/{iso}.log"
    return template.format(log=base_folder, self=self, iso=timestamp)
| |
@property
def log_url(self):
    """
    URL of the web UI page that shows this task instance's log.

    The webserver ``BASE_URL`` setting is prefixed to a query string
    carrying the dag id, task id and ISO 8601 execution date.
    """
    timestamp = self.execution_date.isoformat()
    base_url = configuration.get('webserver', 'BASE_URL')
    relative = (
        "/admin/airflow/log"
        "?dag_id={self.dag_id}"
        "&task_id={self.task_id}"
        "&execution_date={iso}"
    ).format(self=self, iso=timestamp)
    return base_url + relative
| |
@property
def mark_success_url(self):
    """
    URL of the web UI action that marks this task instance as successful.

    The link targets only this task instance: both ``upstream`` and
    ``downstream`` are set to ``false`` in the query string.
    """
    timestamp = self.execution_date.isoformat()
    base_url = configuration.get('webserver', 'BASE_URL')
    relative = (
        "/admin/airflow/action"
        "?action=success"
        "&task_id={self.task_id}"
        "&dag_id={self.dag_id}"
        "&execution_date={iso}"
        "&upstream=false"
        "&downstream=false"
    ).format(self=self, iso=timestamp)
    return base_url + relative
| |
@provide_session
def current_state(self, session=None):
    """
    Return the state currently stored in the database for this task
    instance, or ``None`` when no matching row exists.

    The row is looked up by dag id, task id and execution date. When a
    session is passed the lookup becomes part of that session; otherwise
    a new session is used.

    :param session: database session
    :type session: Session
    """
    TI = TaskInstance
    matches = (
        session.query(TI)
        .filter(
            TI.dag_id == self.dag_id,
            TI.task_id == self.task_id,
            TI.execution_date == self.execution_date,
        )
        .all()
    )
    return matches[0].state if matches else None
| |
@provide_session
def error(self, session=None):
    """
    Force this task instance into the FAILED state and persist it.

    The change is logged at error level, merged into ``session`` and
    committed immediately.

    :param session: database session
    :type session: Session
    """
    self.log.error("Recording the task instance as FAILED")
    # Flip the in-memory state first, then push it to the database.
    self.state = State.FAILED
    session.merge(self)
    session.commit()
| |
@provide_session
def refresh_from_db(self, session=None, lock_for_update=False):
    """
    Reload this task instance's mutable fields from its database row.

    The row is identified by dag id, task id and execution date. When no
    row is found only ``state`` is reset (to ``None``); the other
    attributes are left untouched.

    :param session: database session
    :type session: Session
    :param lock_for_update: if True, indicates that the database should
        lock the TaskInstance (issuing a FOR UPDATE clause) until the
        session is committed.
    :type lock_for_update: bool
    """
    TI = TaskInstance

    query = session.query(TI).filter(
        TI.dag_id == self.dag_id,
        TI.task_id == self.task_id,
        TI.execution_date == self.execution_date)
    if lock_for_update:
        query = query.with_for_update()

    row = query.first()
    if not row:
        self.state = None
        return

    self.state = row.state
    self.start_date = row.start_date
    self.end_date = row.end_date
    # Copy the raw column value: going through the ``try_number``
    # accessor would hand back a value that is already incremented.
    self.try_number = row._try_number
    self.max_tries = row.max_tries
    self.hostname = row.hostname
    self.pid = row.pid
| |
@provide_session
def clear_xcom_data(self, session=None):
    """
    Delete every XCom row belonging to this task instance (same dag id,
    task id and execution date) and commit the deletion.

    :param session: database session
    :type session: Session
    """
    own_xcoms = session.query(XCom).filter(
        XCom.dag_id == self.dag_id,
        XCom.task_id == self.task_id,
        XCom.execution_date == self.execution_date
    )
    own_xcoms.delete()
    session.commit()
| |
@property
def key(self):
    """
    Identifying tuple for this task instance:
    ``(dag_id, task_id, execution_date)``.
    """
    return (self.dag_id, self.task_id, self.execution_date)
| |
def set_state(self, state, session):
    """
    Set the state of this task instance and persist it immediately.

    ``start_date`` and ``end_date`` are both stamped with the same
    current UTC time, so the recorded interval is exactly empty instead
    of differing by the few microseconds between two clock reads.

    :param state: the new state
    :param session: database session used to merge and commit the change
    :type session: Session
    """
    # Read the clock once so both timestamps are guaranteed identical.
    now = datetime.utcnow()
    self.state = state
    self.start_date = now
    self.end_date = now
    session.merge(self)
    session.commit()
| |
@property
def is_premature(self):
    """
    Whether the task instance is in the UP_FOR_RETRY state but still
    inside its retry waiting period, i.e. ``ready_for_retry()`` does not
    hold yet.
    """
    if self.state != State.UP_FOR_RETRY:
        return False
    # Waiting for a retry: premature as long as the delay has not passed.
    return not self.ready_for_retry()
| |
@provide_session
def are_dependents_done(self, session=None):
    """
    Tell whether every downstream task of this task has a SUCCESS task
    instance for the same execution date.

    Meant to be used by wait_for_downstream: it lets the next schedule
    of a task wait until its dependents are done, for instance when the
    task DROPs and recreates a table they read. A task without
    downstream tasks is trivially done.

    :param session: database session
    :type session: Session
    """
    downstream_ids = self.task.downstream_task_ids
    if not downstream_ids:
        return True

    succeeded = session.query(func.count(TaskInstance.task_id)).filter(
        TaskInstance.dag_id == self.dag_id,
        TaskInstance.task_id.in_(downstream_ids),
        TaskInstance.execution_date == self.execution_date,
        TaskInstance.state == State.SUCCESS,
    )
    # First column of the first row holds the COUNT(...) value.
    return succeeded[0][0] == len(downstream_ids)
| |
@property
@provide_session
def previous_ti(self, session=None):
    """
    The task instance of the same task for the preceding run, or
    ``None`` when there is none (no DAG, no earlier schedule, or no
    earlier DAG run).

    :param session: database session
    :type session: Session
    """
    dag = self.task.dag
    if not dag:
        return None

    dag_run = self.get_dagrun(session=session)
    if not dag_run:
        # LEGACY: most likely running from unit tests. This TI is NOT
        # being run from a DagRun, so fall back to the schedule itself.
        earlier_date = dag.previous_schedule(self.execution_date)
        if not earlier_date:
            return None
        return TaskInstance(task=self.task, execution_date=earlier_date)

    dag_run.dag = dag
    if dag.catchup:
        earlier_run = dag_run.get_previous_scheduled_dagrun(session=session)
    else:
        earlier_run = dag_run.get_previous_dagrun(session=session)

    if not earlier_run:
        return None
    return earlier_run.get_task_instance(self.task_id, session=session)
| |
@provide_session
def are_dependencies_met(
        self,
        dep_context=None,
        session=None,
        verbose=False):
    """
    Tell whether every dependency required to run this task instance is
    satisfied in the given dependency context (e.g. a task instance
    being force run from the UI will ignore some dependencies).

    All failing dependencies are walked, not just the first one, so that
    each of them can be reported when ``verbose`` is set.

    :param dep_context: The execution context that determines the dependencies that
        should be evaluated.
    :type dep_context: DepContext
    :param session: database session
    :type session: Session
    :param verbose: whether or not to print details on failed dependencies
    :type verbose: boolean
    :return: True when no dependency failed, False otherwise
    :rtype: bool
    """
    context = dep_context or DepContext()
    all_met = True
    for status in self.get_failed_dep_statuses(
            dep_context=context,
            session=session):
        all_met = False
        if not verbose:
            continue
        self.log.info(
            "Dependencies not met for %s, dependency '%s' FAILED: %s",
            self, status.dep_name, status.reason
        )

    if all_met and verbose:
        self.log.info("Dependencies all met for %s", self)

    return all_met
| |
@provide_session
def get_failed_dep_statuses(
        self,
        dep_context=None,
        session=None):
    """
    Generate the status of every dependency that did not pass.

    The dependencies evaluated are the union of the context's deps and
    the task's own deps. Each status, passed or not, is logged at debug
    level; only the failed ones are yielded.

    :param dep_context: context the dependencies are evaluated in; a
        default ``DepContext()`` is used when omitted
    :type dep_context: DepContext
    :param session: database session
    :type session: Session
    """
    context = dep_context or DepContext()
    all_deps = context.deps | self.task.deps
    for dep in all_deps:
        for status in dep.get_dep_statuses(self, session, context):
            self.log.debug(
                "%s dependency '%s' PASSED: %s, %s",
                self, status.dep_name, status.passed, status.reason
            )
            if status.passed:
                continue
            yield status
| |
def __repr__(self):
    """
    Debug representation showing dag id, task id, execution date and
    current state.
    """
    template = (
        "<TaskInstance: {ti.dag_id}.{ti.task_id} "
        "{ti.execution_date} [{ti.state}]>"
    )
    return template.format(ti=self)
| |
def next_retry_datetime(self):
    """
    Get datetime of the next retry if the task instance fails.

    Without exponential backoff this is ``end_date + task.retry_delay``.
    With exponential backoff, ``retry_delay`` is used as the base and is
    converted to seconds; the delay then lies between ``min_backoff`` and
    ``2 * min_backoff`` seconds, where ``min_backoff`` is
    ``retry_delay * 2 ** (try_number - 2)`` truncated to whole seconds
    and never less than one second. The position within that range is
    derived from a hash of the task instance identity, so it is
    deterministic per task instance and try. ``task.max_retry_delay``,
    when set, caps the backoff delay.

    :return: the earliest datetime at which the task may be retried
    :rtype: datetime
    """
    delay = self.task.retry_delay
    if self.task.retry_exponential_backoff:
        # Truncation yields 0 for small products (e.g. a 1s delay on the
        # first try, or a zero delay); clamp to 1 so that the modulo
        # below cannot raise ZeroDivisionError.
        min_backoff = max(
            1,
            int(delay.total_seconds() * (2 ** (self.try_number - 2))))
        # deterministic per task instance
        ti_hash = int(hashlib.sha1("{}#{}#{}#{}".format(
            self.dag_id, self.task_id,
            self.execution_date, self.try_number
        ).encode('utf-8')).hexdigest(), 16)
        # between min_backoff and 2 * min_backoff (exclusive)
        modded_hash = min_backoff + ti_hash % min_backoff
        # timedelta has a maximum representable value. The exponentiation
        # here means this value can be exceeded after a certain number
        # of tries (around 50 if the initial delay is 1s, even fewer if
        # the delay is larger). Cap the value here before creating a
        # timedelta object so the operation doesn't fail.
        delay_backoff_in_seconds = min(
            modded_hash,
            timedelta.max.total_seconds() - 1
        )
        delay = timedelta(seconds=delay_backoff_in_seconds)
        if self.task.max_retry_delay:
            delay = min(self.task.max_retry_delay, delay)
    return self.end_date + delay
| |
def ready_for_retry(self):
    """
    Tell whether the task instance may be retried now: it must be in the
    UP_FOR_RETRY state and its next retry datetime must lie in the past
    (compared against the current UTC time).
    """
    if self.state != State.UP_FOR_RETRY:
        return False
    return self.next_retry_datetime() < datetime.utcnow()
| |
@provide_session
def pool_full(self, session):
    """
    Tell whether the slot pool of this task is full, i.e. has no open
    slot left for the task to run in.

    Returns False when the task has no pool assigned or when the named
    pool does not exist in the database.

    :param session: database session
    :type session: Session
    """
    pool_name = self.task.pool
    if not pool_name:
        return False

    pool = session.query(Pool).filter(Pool.pool == pool_name).first()
    if not pool:
        return False

    return pool.open_slots(session=session) <= 0
| |
@provide_session
def get_dagrun(self, session):
    """
    Fetch the DagRun this task instance belongs to, matched on dag id
    and execution date.

    :param session: database session
    :type session: Session
    :return: the first matching DagRun, or None when there is none
    """
    matching_runs = session.query(DagRun).filter(
        DagRun.dag_id == self.dag_id,
        DagRun.execution_date == self.execution_date
    )
    return matching_runs.first()
| |
@provide_session
def _check_and_change_state_before_execution(
        self,
        verbose=True,
        ignore_all_deps=False,
        ignore_depends_on_past=False,
        ignore_task_deps=False,
        ignore_ti_state=False,
        mark_success=False,
        test_mode=False,
        job_id=None,
        pool=None,
        session=None):
    """
    Checks dependencies and then sets state to RUNNING if they are met. Returns
    True if and only if state is set to RUNNING, which implies that task should be
    executed, in preparation for _run_raw_task

    The check happens in two stages: first the QUEUE_DEPS, then the
    remaining RUN_DEPS. Failing the first stage leaves the state alone;
    failing the second one (unless ``mark_success`` is set) resets the
    state to NONE so the task instance can be scheduled again.

    :param verbose: whether to turn on more verbose logging
    :type verbose: boolean
    :param ignore_all_deps: Ignore all of the non-critical dependencies, just runs
    :type ignore_all_deps: boolean
    :param ignore_depends_on_past: Ignore depends_on_past DAG attribute
    :type ignore_depends_on_past: boolean
    :param ignore_task_deps: Don't check the dependencies of this TI's task
    :type ignore_task_deps: boolean
    :param ignore_ti_state: Disregards previous task instance state
    :type ignore_ti_state: boolean
    :param mark_success: Don't run the task, mark its state as success
    :type mark_success: boolean
    :param test_mode: Doesn't record success or failure in the DB
    :type test_mode: boolean
    :param job_id: id of the job running this task instance; stored on
        the task instance as ``job_id``
    :param pool: specifies the pool to use to run the task instance
    :type pool: str
    :param session: database session
    :type session: Session
    :return: whether the state was changed to running or not
    :rtype: bool
    """
    task = self.task
    # An explicitly requested pool wins over the pool declared on the task.
    self.pool = pool or task.pool
    self.test_mode = test_mode
    # Reload the row with a FOR UPDATE lock so the state checks and the
    # state change below operate on the latest stored values.
    self.refresh_from_db(session=session, lock_for_update=True)
    self.job_id = job_id
    self.hostname = socket.getfqdn()
    self.operator = task.__class__.__name__

    # Only a statistic: the task instance had already succeeded before.
    if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS:
        Stats.incr('previously_succeeded', 1, 1)

    # Stage 1: dependencies required for the task instance to be queued.
    queue_dep_context = DepContext(
        deps=QUEUE_DEPS,
        ignore_all_deps=ignore_all_deps,
        ignore_ti_state=ignore_ti_state,
        ignore_depends_on_past=ignore_depends_on_past,
        ignore_task_deps=ignore_task_deps)
    # NOTE(review): verbose=True is hard-coded here rather than passing
    # the ``verbose`` argument through - confirm this is intended.
    if not self.are_dependencies_met(
            dep_context=queue_dep_context,
            session=session,
            verbose=True):
        # Commit to end the transaction opened by the locking refresh.
        session.commit()
        return False

    #TODO: Logging needs cleanup, not clear what is being printed
    hr = "\n" + ("-" * 80) + "\n"  # Line break

    # For reporting purposes, we report based on 1-indexed,
    # not 0-indexed lists (i.e. Attempt 1 instead of
    # Attempt 0 for the first attempt).
    msg = "Starting attempt {attempt} of {total}".format(
        attempt=self.try_number,
        total=self.max_tries + 1)
    self.start_date = datetime.utcnow()

    # Stage 2: the dependencies needed to run that were not already
    # covered by the queueing check above.
    dep_context = DepContext(
        deps=RUN_DEPS - QUEUE_DEPS,
        ignore_all_deps=ignore_all_deps,
        ignore_depends_on_past=ignore_depends_on_past,
        ignore_task_deps=ignore_task_deps,
        ignore_ti_state=ignore_ti_state)
    runnable = self.are_dependencies_met(
        dep_context=dep_context,
        session=session,
        verbose=True)

    if not runnable and not mark_success:
        # FIXME: we might have hit concurrency limits, which means we probably
        # have been running prematurely. This should be handled in the
        # scheduling mechanism.
        self.state = State.NONE
        msg = ("FIXME: Rescheduling due to concurrency limits reached at task "
               "runtime. Attempt {attempt} of {total}. State set to NONE.").format(
            attempt=self.try_number,
            total=self.max_tries + 1)
        self.log.warning(hr + msg + hr)

        self.queued_dttm = datetime.utcnow()
        self.log.info("Queuing into pool %s", self.pool)
        session.merge(self)
        session.commit()
        return False

    # Another worker might have started running this task instance while
    # the current worker process was blocked on refresh_from_db
    if self.state == State.RUNNING:
        msg = "Task Instance already running {}".format(self)
        self.log.warning(msg)
        session.commit()
        return False

    # print status message
    self.log.info(hr + msg + hr)
    # Bump the raw counter: this attempt is now considered started.
    self._try_number += 1

    # In test mode neither the log entry nor the task instance itself is
    # written to the session; the in-memory state still becomes RUNNING.
    if not test_mode:
        session.add(Log(State.RUNNING, self))
    self.state = State.RUNNING
    self.pid = os.getpid()
    self.end_date = None
    if not test_mode:
        session.merge(self)
    session.commit()

    # Closing all pooled connections to prevent
    # "max number of connections reached"
    settings.engine.dispose()
    if verbose:
        if mark_success:
            msg = "Marking success for {} on {}".format(self.task, self.execution_date)
            self.log.info(msg)
        else:
            msg = "Executing {} on {}".format(self.task, self.execution_date)
            self.log.info(msg)
    return True
| |
@provide_session
def _run_raw_task(
        self,
        mark_success=False,
        test_mode=False,
        job_id=None,
        pool=None,
        session=None):
    """
    Immediately runs the task (without checking or changing db state
    before execution) and then sets the appropriate final state after
    completion and runs any post-execute callbacks. Meant to be called
    only after another function changes the state to running.

    Outcomes: SUCCESS when execution completes (or ``mark_success`` is
    set), SKIPPED when AirflowSkipException is raised. Any other
    exception is passed to ``handle_failure`` and re-raised, except an
    AirflowException raised while the stored state is already SUCCESS,
    in which case the method returns silently.

    :param mark_success: Don't run the task, mark its state as success
    :type mark_success: boolean
    :param test_mode: Doesn't record success or failure in the DB
    :type test_mode: boolean
    :param job_id: id of the job running this task instance; stored on
        the task instance as ``job_id``
    :param pool: specifies the pool to use to run the task instance
    :type pool: str
    :param session: database session
    :type session: Session
    """
    task = self.task
    self.pool = pool or task.pool
    self.test_mode = test_mode
    self.refresh_from_db(session=session)
    self.job_id = job_id
    self.hostname = socket.getfqdn()
    self.operator = task.__class__.__name__

    # Stays empty when mark_success is set; callbacks then get ``{}``.
    context = {}
    try:
        if not mark_success:
            context = self.get_template_context()

            # Work on a shallow copy so that rendering templates does not
            # modify the attributes of the original task object.
            task_copy = copy.copy(task)
            self.task = task_copy

            def signal_handler(signum, frame):
                """Setting kill signal handler"""
                self.log.error("Killing subprocess")
                task_copy.on_kill()
                raise AirflowException("Task received SIGTERM signal")
            # On SIGTERM: give the operator a chance to clean up, then
            # fail the task through the AirflowException branch below.
            signal.signal(signal.SIGTERM, signal_handler)

            # Don't clear Xcom until the task is certain to execute
            self.clear_xcom_data()

            self.render_templates()
            task_copy.pre_execute(context=context)

            # If a timeout is specified for the task, make it fail
            # if it goes beyond
            result = None
            if task_copy.execution_timeout:
                try:
                    with timeout(int(
                            task_copy.execution_timeout.total_seconds())):
                        result = task_copy.execute(context=context)
                except AirflowTaskTimeout:
                    # Let the operator clean up before the timeout is
                    # propagated as a failure.
                    task_copy.on_kill()
                    raise
            else:
                result = task_copy.execute(context=context)

            # If the task returns a result, push an XCom containing it
            if result is not None:
                self.xcom_push(key=XCOM_RETURN_KEY, value=result)

            # TODO remove deprecated behavior in Airflow 2.0
            try:
                task_copy.post_execute(context=context, result=result)
            except TypeError as e:
                # Operators overriding post_execute with the old
                # one-argument signature reject the ``result`` keyword;
                # warn and call them again the old way.
                if 'unexpected keyword argument' in str(e):
                    warnings.warn(
                        'BaseOperator.post_execute() now takes two '
                        'arguments, `context` and `result`, but "{}" only '
                        'expected one. This behavior is deprecated and '
                        'will be removed in a future version of '
                        'Airflow.'.format(self.task_id),
                        category=DeprecationWarning)
                    task_copy.post_execute(context=context)
                else:
                    raise

            Stats.incr('operator_successes_{}'.format(
                self.task.__class__.__name__), 1, 1)
            Stats.incr('ti_successes')
        # Reached both after a real execution and with mark_success.
        self.refresh_from_db(lock_for_update=True)
        self.state = State.SUCCESS
    except AirflowSkipException:
        self.refresh_from_db(lock_for_update=True)
        self.state = State.SKIPPED
    except AirflowException as e:
        self.refresh_from_db()
        # for case when task is marked as success externally
        # current behavior doesn't hit the success callback
        if self.state == State.SUCCESS:
            return
        else:
            self.handle_failure(e, test_mode, context)
            raise
    except (Exception, KeyboardInterrupt) as e:
        self.handle_failure(e, test_mode, context)
        raise

    # Recording SUCCESS
    # (this point is also reached with the SKIPPED state set above)
    self.end_date = datetime.utcnow()
    self.set_duration()
    if not test_mode:
        session.add(Log(self.state, self))
        session.merge(self)
    session.commit()

    # Success callback
    # NOTE(review): this also runs when the state is SKIPPED - confirm
    # that is intended. Callback errors are logged, never propagated.
    try:
        if task.on_success_callback:
            task.on_success_callback(context)
    except Exception as e3:
        self.log.error("Failed when executing success callback")
        self.log.exception(e3)

    session.commit()
| |
@provide_session
def run(
        self,
        verbose=True,
        ignore_all_deps=False,
        ignore_depends_on_past=False,
        ignore_task_deps=False,
        ignore_ti_state=False,
        mark_success=False,
        test_mode=False,
        job_id=None,
        pool=None,
        session=None):
    """
    Run this task instance: check dependencies and move the state to
    RUNNING via ``_check_and_change_state_before_execution``, then, only
    if that succeeded, execute the task via ``_run_raw_task``.

    All arguments are forwarded unchanged to those two methods; see
    ``_check_and_change_state_before_execution`` for their meaning.
    """
    state_is_running = self._check_and_change_state_before_execution(
        verbose=verbose,
        ignore_all_deps=ignore_all_deps,
        ignore_depends_on_past=ignore_depends_on_past,
        ignore_task_deps=ignore_task_deps,
        ignore_ti_state=ignore_ti_state,
        mark_success=mark_success,
        test_mode=test_mode,
        job_id=job_id,
        pool=pool,
        session=session)
    if not state_is_running:
        return

    self._run_raw_task(
        mark_success=mark_success,
        test_mode=test_mode,
        job_id=job_id,
        pool=pool,
        session=session)
| |
def dry_run(self):
    """
    Render the templates on a shallow copy of the task and invoke the
    copy's ``dry_run``. The copy replaces ``self.task``.
    """
    self.task = task_copy = copy.copy(self.task)

    self.render_templates()
    task_copy.dry_run()
| |
def handle_failure(self, error, test_mode=False, context=None):
    """
    Record a failed attempt of this task instance.

    Sets ``end_date`` and the duration, decides between UP_FOR_RETRY and
    FAILED, sends the configured alert email, runs the matching
    retry/failure callback and persists the outcome. Errors raised while
    emailing or inside the callbacks are logged and swallowed so that
    the bookkeeping always completes.

    :param error: the exception that made the task fail
    :param test_mode: when True, the Log entry and the task instance
        itself are not written to the session
    :type test_mode: boolean
    :param context: template context handed to the callbacks; may be None
    :type context: dict
    """
    self.log.exception(error)
    task = self.task
    # NOTE(review): this session is never closed in this method - confirm
    # whether settings.Session() hands out a shared (scoped) session.
    session = settings.Session()
    self.end_date = datetime.utcnow()
    self.set_duration()
    Stats.incr('operator_failures_{}'.format(task.__class__.__name__), 1, 1)
    Stats.incr('ti_failures')
    if not test_mode:
        session.add(Log(State.FAILED, self))

    # Log failure duration
    # NOTE(review): added regardless of test_mode, and TaskFail subtracts
    # the two dates - presumably start_date is always set here; verify.
    session.add(TaskFail(task, self.execution_date, self.start_date, self.end_date))

    # Let's go deeper
    try:
        # Since this function is called only when the TI state is running,
        # try_number contains the current try_number (not the next). We
        # only mark task instance as FAILED if the next task instance
        # try_number exceeds the max_tries.
        if task.retries and self.try_number <= self.max_tries:
            self.state = State.UP_FOR_RETRY
            self.log.info('Marking task as UP_FOR_RETRY')
            if task.email_on_retry and task.email:
                self.email_alert(error, is_retry=True)
        else:
            self.state = State.FAILED
            if task.retries:
                self.log.info('All retries failed; marking task as FAILED')
            else:
                self.log.info('Marking task as FAILED.')
            if task.email_on_failure and task.email:
                self.email_alert(error, is_retry=False)
    except Exception as e2:
        # The state was assigned before the email attempt, so a mail
        # failure does not lose the state decision.
        self.log.error('Failed to send email to: %s', task.email)
        self.log.exception(e2)

    # Handling callbacks pessimistically
    try:
        if self.state == State.UP_FOR_RETRY and task.on_retry_callback:
            task.on_retry_callback(context)
        if self.state == State.FAILED and task.on_failure_callback:
            task.on_failure_callback(context)
    except Exception as e3:
        self.log.error("Failed at executing callback")
        self.log.exception(e3)

    if not test_mode:
        session.merge(self)
    session.commit()
    self.log.error(str(error))
| |
@provide_session
def get_template_context(self, session=None):
    """
    Build the dictionary of variables made available to templates and
    callbacks for this task instance.

    It contains date strings derived from the execution date (``ds``,
    ``ts``, their ``*_nodash`` variants, yesterday/tomorrow), the
    previous and following schedule dates, the task, DAG and DagRun
    objects, the merged ``params`` (task params override DAG params)
    and accessors for Airflow Variables under ``var``.

    :param session: database session used to look up the DagRun
    :type session: Session
    :return: the template context
    :rtype: dict
    """
    task = self.task
    from airflow import macros
    tables = None
    if 'tables' in task.params:
        tables = task.params['tables']

    # First 10 characters of the ISO form: YYYY-MM-DD.
    ds = self.execution_date.isoformat()[:10]
    ts = self.execution_date.isoformat()
    # timedelta(1) is one day.
    yesterday_ds = (self.execution_date - timedelta(1)).isoformat()[:10]
    tomorrow_ds = (self.execution_date + timedelta(1)).isoformat()[:10]

    # NOTE(review): task.dag is used here unconditionally although the
    # code further down guards it with hasattr(task, 'dag') - confirm.
    prev_execution_date = task.dag.previous_schedule(self.execution_date)
    next_execution_date = task.dag.following_schedule(self.execution_date)

    ds_nodash = ds.replace('-', '')
    ts_nodash = ts.replace('-', '').replace(':', '')
    yesterday_ds_nodash = yesterday_ds.replace('-', '')
    tomorrow_ds_nodash = tomorrow_ds.replace('-', '')

    # The template is filled from the local variables defined above.
    ti_key_str = "{task.dag_id}__{task.task_id}__{ds_nodash}"
    ti_key_str = ti_key_str.format(**locals())

    params = {}
    run_id = ''
    dag_run = None
    if hasattr(task, 'dag'):
        # DAG-level params first; task-level params override them below.
        if task.dag.params:
            params.update(task.dag.params)
        dag_run = (
            session.query(DagRun)
            .filter_by(
                dag_id=task.dag.dag_id,
                execution_date=self.execution_date)
            .first()
        )
        # Note: becomes None (not '') when no DagRun was found.
        run_id = dag_run.run_id if dag_run else None
        # Detach everything from the session so the DagRun placed in the
        # context is no longer tied to it.
        session.expunge_all()
        session.commit()

    if task.params:
        params.update(task.params)

    class VariableAccessor:
        """
        Wrapper around Variable. Attribute access fetches the Variable
        of that name, so templates can read variables through the
        ``var`` entry of the context, e.g. ``var.value.variable_name``.
        """
        def __init__(self):
            self.var = None

        def __getattr__(self, item):
            # Invoked for any attribute not found normally; the name is
            # used as the Variable key. The last value is remembered.
            self.var = Variable.get(item)
            return self.var

        def __repr__(self):
            return str(self.var)

    class VariableJsonAccessor:
        """
        Same as VariableAccessor, but the Variable is fetched with
        ``deserialize_json=True`` (``var.json.variable_name``).
        """
        def __init__(self):
            self.var = None

        def __getattr__(self, item):
            self.var = Variable.get(item, deserialize_json=True)
            return self.var

        def __repr__(self):
            return str(self.var)

    return {
        'dag': task.dag,
        'ds': ds,
        'ds_nodash': ds_nodash,
        'ts': ts,
        'ts_nodash': ts_nodash,
        'yesterday_ds': yesterday_ds,
        'yesterday_ds_nodash': yesterday_ds_nodash,
        'tomorrow_ds': tomorrow_ds,
        'tomorrow_ds_nodash': tomorrow_ds_nodash,
        # The next two entries and 'latest_date' are aliases of ``ds``.
        'END_DATE': ds,
        'end_date': ds,
        'dag_run': dag_run,
        'run_id': run_id,
        'execution_date': self.execution_date,
        'prev_execution_date': prev_execution_date,
        'next_execution_date': next_execution_date,
        'latest_date': ds,
        'macros': macros,
        'params': params,
        'tables': tables,
        'task': task,
        'task_instance': self,
        'ti': self,
        'task_instance_key_str': ti_key_str,
        'conf': configuration,
        'test_mode': self.test_mode,
        'var': {
            'value': VariableAccessor(),
            'json': VariableJsonAccessor()
        }
    }
| |
def render_templates(self):
    """
    Render every templated field of the task in place.

    The rendering context is the template context of this task instance,
    extended with the DAG's ``user_defined_macros`` when there are any.
    Fields listed in the task class's ``template_fields`` whose current
    value is falsy are left untouched.
    """
    task = self.task
    context = self.get_template_context()
    if hasattr(self, 'task') and hasattr(self.task, 'dag'):
        user_macros = self.task.dag.user_defined_macros
        if user_macros:
            context.update(user_macros)

    render = self.task.render_template
    for field in task.__class__.template_fields:
        raw_value = getattr(task, field)
        if not raw_value:
            continue
        setattr(task, field, render(field, raw_value, context))
| |
def email_alert(self, exception, is_retry=False):
    """
    Send an HTML alert email about a failed attempt to ``task.email``.

    The body lists the attempt number, the exception text (newlines
    turned into ``<br>``), the host, the log location and a link to mark
    the task instance as successful.

    :param exception: the error to report
    :param is_retry: accepted for the caller's convenience; not used
        when building the message
    :type is_retry: bool
    """
    task = self.task
    title = "Airflow alert: {self}".format(self=self)
    html_exception = str(exception).replace('\n', '<br>')
    # For reporting purposes, we report based on 1-indexed,
    # not 0-indexed lists (i.e. Try 1 instead of
    # Try 0 for the first attempt).
    template = (
        "Try {try_number} out of {max_tries}<br>"
        "Exception:<br>{exception}<br>"
        "Log: <a href='{self.log_url}'>Link</a><br>"
        "Host: {self.hostname}<br>"
        "Log file: {self.log_filepath}<br>"
        "Mark success: <a href='{self.mark_success_url}'>Link</a><br>"
    )
    body = template.format(
        try_number=self.try_number,
        max_tries=self.max_tries + 1,
        exception=html_exception,
        self=self)
    send_email(task.email, title, body)
| |
def set_duration(self):
    """
    Store in ``duration`` the number of seconds between ``start_date``
    and ``end_date``, or ``None`` when either of them is missing.
    """
    has_both_dates = self.end_date and self.start_date
    self.duration = (
        (self.end_date - self.start_date).total_seconds()
        if has_both_dates else None
    )
| |
def xcom_push(
        self,
        key,
        value,
        execution_date=None):
    """
    Make an XCom available for tasks to pull.

    :param key: A key for the XCom
    :type key: string
    :param value: A value for the XCom. The value is pickled and stored
        in the database.
    :type value: any pickleable object
    :param execution_date: if provided, the XCom will not be visible until
        this date. This can be used, for example, to send a message to a
        task on a future date without it being immediately visible.
        Defaults to the execution date of this task instance.
    :type execution_date: datetime
    :raises ValueError: if ``execution_date`` is earlier than the
        execution date of this task instance
    """
    own_date = self.execution_date
    if execution_date and execution_date < own_date:
        message = (
            'execution_date can not be in the past (current '
            'execution_date is {}; received {})'
        ).format(own_date, execution_date)
        raise ValueError(message)

    XCom.set(
        key=key,
        value=value,
        task_id=self.task_id,
        dag_id=self.dag_id,
        execution_date=execution_date or self.execution_date)
| |
def xcom_pull(
        self,
        task_ids,
        dag_id=None,
        key=XCOM_RETURN_KEY,
        include_prior_dates=False):
    """
    Pull XComs that optionally meet certain criteria.

    The default value for `key` limits the search to XComs
    that were returned by other tasks (as opposed to those that were pushed
    manually). To remove this filter, pass key=None (or any desired value).

    If a single task_id string is provided, the result is the value of the
    most recent matching XCom from that task_id. If multiple task_ids are
    provided, a tuple of matching values is returned. None is returned
    whenever no matches are found.

    :param task_ids: Only XComs from tasks with matching ids will be
        pulled. Can pass None to remove the filter.
    :type task_ids: string or iterable of strings (representing task_ids)
    :param dag_id: If provided, only pulls XComs from this DAG.
        If None (default), the DAG of the calling task is used.
    :type dag_id: string
    :param key: A key for the XCom. If provided, only XComs with matching
        keys will be returned. The default key is 'return_value', also
        available as a constant XCOM_RETURN_KEY. This key is automatically
        given to XComs returned by tasks (as opposed to being pushed
        manually). To remove the filter, pass key=None.
    :type key: string
    :param include_prior_dates: If False, only XComs from the current
        execution_date are returned. If True, XComs from previous dates
        are returned as well.
    :type include_prior_dates: bool
    """
    source_dag_id = self.dag_id if dag_id is None else dag_id

    # Everything but the task id is fixed for all lookups.
    pull_one = functools.partial(
        XCom.get_one,
        execution_date=self.execution_date,
        key=key,
        dag_id=source_dag_id,
        include_prior_dates=include_prior_dates)

    if not is_container(task_ids):
        return pull_one(task_id=task_ids)
    return tuple(pull_one(task_id=tid) for tid in task_ids)
| |
@provide_session
def get_num_running_task_instances(self, session):
    """
    Count the task instances of this task (same dag id and task id, any
    execution date) that are currently in the RUNNING state.

    :param session: database session
    :type session: Session
    :return: number of running task instances
    :rtype: int
    """
    TI = TaskInstance
    running = session.query(TI).filter(
        TI.dag_id == self.dag_id,
        TI.task_id == self.task_id,
        TI.state == State.RUNNING
    )
    return running.count()
| |
| |
class TaskFail(Base):
    """
    TaskFail tracks the failed run durations of each task instance.
    """

    __tablename__ = "task_fail"

    # A row is identified by task id, dag id and execution date.
    task_id = Column(String(ID_LEN), primary_key=True)
    dag_id = Column(String(ID_LEN), primary_key=True)
    execution_date = Column(DateTime, primary_key=True)
    start_date = Column(DateTime)
    end_date = Column(DateTime)
    # Seconds between start_date and end_date; None when either is unknown.
    duration = Column(Float)

    def __init__(self, task, execution_date, start_date, end_date):
        """
        :param task: the task that failed; provides ``dag_id`` and ``task_id``
        :param execution_date: execution date of the failed task instance
        :type execution_date: datetime
        :param start_date: when the failed attempt started, or None
        :type start_date: datetime
        :param end_date: when the failed attempt ended, or None
        :type end_date: datetime
        """
        self.dag_id = task.dag_id
        self.task_id = task.task_id
        self.execution_date = execution_date
        self.start_date = start_date
        self.end_date = end_date
        # A failure can be recorded for an attempt whose start or end was
        # never stamped; leave the duration empty instead of raising
        # TypeError on ``None - datetime``.
        if self.start_date and self.end_date:
            self.duration = (self.end_date - self.start_date).total_seconds()
        else:
            self.duration = None
| |
| |
class Log(Base):
    """
    Used to actively log events to the database
    """

    __tablename__ = "log"

    id = Column(Integer, primary_key=True)
    dttm = Column(DateTime)
    dag_id = Column(String(ID_LEN))
    task_id = Column(String(ID_LEN))
    event = Column(String(30))
    execution_date = Column(DateTime)
    owner = Column(String(500))
    extra = Column(Text)

    def __init__(self, event, task_instance, owner=None, extra=None, **kwargs):
        """
        :param event: name of the event being logged
        :param task_instance: task instance the event relates to, or a
            falsy value when there is none; provides dag id, task id,
            execution date and the default owner
        :param owner: owner to record; defaults to the task's owner
        :param extra: free-form additional information
        :param kwargs: optional ``task_id``, ``dag_id`` and
            ``execution_date`` overriding the values taken from
            ``task_instance`` (a falsy ``execution_date`` is ignored)
        """
        self.dttm = datetime.utcnow()
        self.event = event
        self.extra = extra

        task_owner = None
        if task_instance:
            self.dag_id = task_instance.dag_id
            self.task_id = task_instance.task_id
            self.execution_date = task_instance.execution_date
            task_owner = task_instance.task.owner

        # Explicit keyword values win over the task instance's values.
        for field in ('task_id', 'dag_id'):
            if field in kwargs:
                setattr(self, field, kwargs[field])
        if kwargs.get('execution_date'):
            self.execution_date = kwargs['execution_date']

        self.owner = owner or task_owner
| |
| |
class SkipMixin(LoggingMixin):
    """
    Mixin providing ``skip``, which marks task instances as SKIPPED.
    """

    def skip(self, dag_run, execution_date, tasks):
        """
        Sets tasks instances to skipped from the same dag run.

        With a DagRun, the matching task instances are updated in bulk.
        Without one, a task instance is built per task for
        ``execution_date`` and merged into the session. In both cases
        start and end date are set to the current UTC time. The session
        opened here is closed even when the update fails.

        :param dag_run: the DagRun for which to set the tasks to skipped
        :param execution_date: execution_date
        :param tasks: tasks to skip (not task_ids)
        """
        if not tasks:
            return

        task_ids = [d.task_id for d in tasks]
        now = datetime.utcnow()
        session = settings.Session()
        # try/finally: do not leak the session when the query, the merge
        # or the assertion below raises.
        try:
            if dag_run:
                session.query(TaskInstance).filter(
                    TaskInstance.dag_id == dag_run.dag_id,
                    TaskInstance.execution_date == dag_run.execution_date,
                    TaskInstance.task_id.in_(task_ids)
                ).update({TaskInstance.state: State.SKIPPED,
                          TaskInstance.start_date: now,
                          TaskInstance.end_date: now},
                         synchronize_session=False)
                session.commit()
            else:
                assert execution_date is not None, "Execution date is None and no dag run"

                self.log.warning("No DAG RUN present this should not happen")
                # this is defensive against dag runs that are not complete
                for task in tasks:
                    ti = TaskInstance(task, execution_date=execution_date)
                    ti.state = State.SKIPPED
                    ti.start_date = now
                    ti.end_date = now
                    session.merge(ti)

                session.commit()
        finally:
            session.close()
| |
| |
| @functools.total_ordering |
| class BaseOperator(LoggingMixin): |
| """ |
| Abstract base class for all operators. Since operators create objects that |
| become node in the dag, BaseOperator contains many recursive methods for |
| dag crawling behavior. To derive this class, you are expected to override |
| the constructor as well as the 'execute' method. |
| |
| Operators derived from this class should perform or trigger certain tasks |
| synchronously (wait for completion). Example of operators could be an |
| operator that runs a Pig job (PigOperator), a sensor operator that |
| waits for a partition to land in Hive (HiveSensorOperator), or one that |
| moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these |
| operators (tasks) target specific operations, running specific scripts, |
| functions or data transfers. |
| |
| This class is abstract and shouldn't be instantiated. Instantiating a |
| class derived from this one results in the creation of a task object, |
| which ultimately becomes a node in DAG objects. Task dependencies should |
| be set by using the set_upstream and/or set_downstream methods. |
| |
| :param task_id: a unique, meaningful id for the task |
| :type task_id: string |
| :param owner: the owner of the task, using the unix username is recommended |
| :type owner: string |
| :param retries: the number of retries that should be performed before |
| failing the task |
| :type retries: int |
| :param retry_delay: delay between retries |
| :type retry_delay: timedelta |
| :param retry_exponential_backoff: allow progressive longer waits between |
| retries by using exponential backoff algorithm on retry delay (delay |
| will be converted into seconds) |
| :type retry_exponential_backoff: bool |
| :param max_retry_delay: maximum delay interval between retries |
| :type max_retry_delay: timedelta |
| :param start_date: The ``start_date`` for the task, determines |
| the ``execution_date`` for the first task instance. The best practice |
| is to have the start_date rounded |
| to your DAG's ``schedule_interval``. Daily jobs have their start_date |
| some day at 00:00:00, hourly jobs have their start_date at 00:00 |
| of a specific hour. Note that Airflow simply looks at the latest |
| ``execution_date`` and adds the ``schedule_interval`` to determine |
| the next ``execution_date``. It is also very important |
| to note that different tasks' dependencies |
| need to line up in time. If task A depends on task B and their |
| start_date are offset in a way that their execution_date don't line |
| up, A's dependencies will never be met. If you are looking to delay |
| a task, for example running a daily task at 2AM, look into the |
| ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using |
| dynamic ``start_date`` and recommend using fixed ones. Read the |
| FAQ entry about start_date for more information. |
| :type start_date: datetime |
| :param end_date: if specified, the scheduler won't go beyond this date |
| :type end_date: datetime |
| :param depends_on_past: when set to true, task instances will run |
| sequentially while relying on the previous task's schedule to |
| succeed. The task instance for the start_date is allowed to run. |
| :type depends_on_past: bool |
| :param wait_for_downstream: when set to true, an instance of task |
| X will wait for tasks immediately downstream of the previous instance |
| of task X to finish successfully before it runs. This is useful if the |
| different instances of a task X alter the same asset, and this asset |
| is used by tasks downstream of task X. Note that depends_on_past |
| is forced to True wherever wait_for_downstream is used. |
| :type wait_for_downstream: bool |
| :param queue: which queue to target when running this job. Not |
| all executors implement queue management, the CeleryExecutor |
| does support targeting specific queues. |
| :type queue: str |
| :param dag: a reference to the dag the task is attached to (if any) |
| :type dag: DAG |
| :param priority_weight: priority weight of this task against other task. |
| This allows the executor to trigger higher priority tasks before |
| others when things get backed up. |
| :type priority_weight: int |
| :param pool: the slot pool this task should run in, slot pools are a |
| way to limit concurrency for certain tasks |
| :type pool: str |
| :param sla: time by which the job is expected to succeed. Note that |
| this represents the ``timedelta`` after the period is closed. For |
| example if you set an SLA of 1 hour, the scheduler would send an email |
| soon after 1:00AM on the ``2016-01-02`` if the ``2016-01-01`` instance |
| has not succeeded yet. |
| The scheduler pays special attention for jobs with an SLA and |
| sends alert |
| emails for sla misses. SLA misses are also recorded in the database |
| for future reference. All tasks that share the same SLA time |
| get bundled in a single email, sent soon after that time. SLA |
| notifications are sent once and only once for each task instance. |
| :type sla: datetime.timedelta |
| :param execution_timeout: max time allowed for the execution of |
| this task instance, if it goes beyond it will raise and fail. |
| :type execution_timeout: datetime.timedelta |
| :param on_failure_callback: a function to be called when a task instance |
| of this task fails. a context dictionary is passed as a single |
| parameter to this function. Context contains references to related |
| objects to the task instance and is documented under the macros |
| section of the API. |
| :type on_failure_callback: callable |
| :param on_retry_callback: much like the ``on_failure_callback`` except |
| that it is executed when retries occur. |
| :type on_retry_callback: callable |
| :param on_success_callback: much like the ``on_failure_callback`` except |
| that it is executed when the task succeeds. |
| :type on_success_callback: callable |
| :param trigger_rule: defines the rule by which dependencies are applied |
| for the task to get triggered. Options are: |
| ``{ all_success | all_failed | all_done | one_success | |
| one_failed | dummy}`` |
| default is ``all_success``. Options can be set as string or |
| using the constants defined in the static class |
| ``airflow.utils.TriggerRule`` |
| :type trigger_rule: str |
| :param resources: A map of resource parameter names (the argument names of the |
| Resources constructor) to their values. |
| :type resources: dict |
| :param run_as_user: unix username to impersonate while running the task |
| :type run_as_user: str |
| :param task_concurrency: When set, a task will be able to limit the concurrent |
| runs across execution_dates |
| :type task_concurrency: int |
| """ |
| |
# For derived classes to define which fields will get jinjaified: the names
# listed here are the attributes looked up when templates are resolved.
template_fields = []
# Defines which files extensions to look for in the templated fields: a
# templated string ending with one of these is treated as a template file.
template_ext = []
# Defines the color in the UI
ui_color = '#fff'
# NOTE(review): presumably the foreground (text) color in the UI - confirm.
ui_fgcolor = '#000'
| |
@apply_defaults
def __init__(
        self,
        task_id,
        owner=configuration.get('operators', 'DEFAULT_OWNER'),
        email=None,
        email_on_retry=True,
        email_on_failure=True,
        retries=0,
        retry_delay=timedelta(seconds=300),
        retry_exponential_backoff=False,
        max_retry_delay=None,
        start_date=None,
        end_date=None,
        schedule_interval=None,  # not hooked as of now
        depends_on_past=False,
        wait_for_downstream=False,
        dag=None,
        params=None,
        default_args=None,
        adhoc=False,
        priority_weight=1,
        queue=configuration.get('celery', 'default_queue'),
        pool=None,
        sla=None,
        execution_timeout=None,
        on_failure_callback=None,
        on_success_callback=None,
        on_retry_callback=None,
        trigger_rule=TriggerRule.ALL_SUCCESS,
        resources=None,
        run_as_user=None,
        task_concurrency=None,
        *args,
        **kwargs):
    """
    Validate the arguments and store them on the task.

    The parameters are described in the class docstring. Unknown positional
    or keyword arguments only trigger a ``PendingDeprecationWarning``.

    :raises AirflowException: when ``trigger_rule`` is not a valid rule
    """

    if args or kwargs:
        # TODO remove *args and **kwargs in Airflow 2.0
        warnings.warn(
            'Invalid arguments were passed to {c}. Support for '
            'passing such arguments will be dropped in Airflow 2.0. '
            'Invalid arguments were:'
            '\n*args: {a}\n**kwargs: {k}'.format(
                c=self.__class__.__name__, a=args, k=kwargs),
            category=PendingDeprecationWarning
        )

    validate_key(task_id)
    self.task_id = task_id
    self.owner = owner
    self.email = email
    self.email_on_retry = email_on_retry
    self.email_on_failure = email_on_failure
    self.start_date = start_date
    if start_date and not isinstance(start_date, datetime):
        self.log.warning("start_date for %s isn't datetime.datetime", self)
    self.end_date = end_date
    if not TriggerRule.is_valid(trigger_rule):
        # ``dag`` is optional: never dereference it blindly, otherwise an
        # AttributeError would mask the real validation error.
        raise AirflowException(
            "The trigger_rule must be one of {all_triggers},"
            "'{d}.{t}'; received '{tr}'."
            .format(all_triggers=TriggerRule.all_triggers,
                    d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))

    self.trigger_rule = trigger_rule
    self.depends_on_past = depends_on_past
    self.wait_for_downstream = wait_for_downstream
    if wait_for_downstream:
        # wait_for_downstream forces depends_on_past to True
        self.depends_on_past = True

    if schedule_interval:
        self.log.warning(
            "schedule_interval is used for %s, though it has "
            "been deprecated as a task parameter, you need to "
            "specify it as a DAG parameter instead",
            self
        )
    self._schedule_interval = schedule_interval
    self.retries = retries
    self.queue = queue
    self.pool = pool
    self.sla = sla
    self.execution_timeout = execution_timeout
    self.on_failure_callback = on_failure_callback
    self.on_success_callback = on_success_callback
    self.on_retry_callback = on_retry_callback
    if isinstance(retry_delay, timedelta):
        self.retry_delay = retry_delay
    else:
        # Anything that is not a timedelta is interpreted as seconds.
        self.log.debug("Retry_delay isn't timedelta object, assuming secs")
        self.retry_delay = timedelta(seconds=retry_delay)
    self.retry_exponential_backoff = retry_exponential_backoff
    self.max_retry_delay = max_retry_delay
    self.params = params or {}  # Available in templates!
    self.adhoc = adhoc
    self.priority_weight = priority_weight
    self.resources = Resources(**(resources or {}))
    self.run_as_user = run_as_user
    self.task_concurrency = task_concurrency

    # Private attributes
    self._upstream_task_ids = []
    self._downstream_task_ids = []

    # Fall back on the DAG of the enclosing ``with DAG(...)`` block, if any.
    if not dag and _CONTEXT_MANAGER_DAG:
        dag = _CONTEXT_MANAGER_DAG
    if dag:
        self.dag = dag

    # Attribute names taken into account by __eq__ and __hash__.
    self._comps = {
        'task_id',
        'dag_id',
        'owner',
        'email',
        'email_on_retry',
        'retry_delay',
        'retry_exponential_backoff',
        'max_retry_delay',
        'start_date',
        'schedule_interval',
        'depends_on_past',
        'wait_for_downstream',
        'adhoc',
        'priority_weight',
        'sla',
        'execution_timeout',
        'on_failure_callback',
        'on_success_callback',
        'on_retry_callback',
    }
| |
def __eq__(self, other):
    """
    Two tasks are equal when they have the same type and every attribute
    named in ``_comps`` compares equal.

    Values are read with ``getattr`` (not ``__dict__``) so that
    property-backed components such as ``dag_id`` are really compared;
    this keeps equality consistent with ``__hash__``, which also uses
    ``getattr``.
    """
    if not (type(self) == type(other)):
        return False
    return all(
        getattr(self, c, None) == getattr(other, c, None)
        for c in self._comps)
| |
def __ne__(self, other):
    """Negation of ``__eq__``."""
    is_equal = self == other
    return not is_equal
| |
def __lt__(self, other):
    """Order tasks by their ``task_id``."""
    mine, theirs = self.task_id, other.task_id
    return mine < theirs
| |
def __hash__(self):
    """
    Hash the type together with every attribute named in ``_comps``.

    Unhashable attribute values are replaced by their ``repr``.
    """
    def as_hashable(value):
        try:
            hash(value)
        except TypeError:
            return repr(value)
        return value

    parts = [type(self)]
    parts.extend(
        as_hashable(getattr(self, name, None)) for name in self._comps)
    return hash(tuple(parts))
| |
| # Composing Operators ----------------------------------------------- |
| |
def __rshift__(self, other):
    """
    Implements Self >> Other == self.set_downstream(other)

    If "Other" is a DAG, the DAG is assigned to the Operator instead.
    ``other`` is returned in both cases.
    """
    if not isinstance(other, DAG):
        self.set_downstream(other)
        return other

    # Re-assigning the DAG this task already belongs to is a no-op;
    # anything else goes through the normal dag assignment.
    already_assigned = self.has_dag() and self.dag is other
    if not already_assigned:
        self.dag = other
    return other
| |
def __lshift__(self, other):
    """
    Implements Self << Other == self.set_upstream(other)

    If "Other" is a DAG, the DAG is assigned to the Operator instead.
    ``other`` is returned in both cases.
    """
    if not isinstance(other, DAG):
        self.set_upstream(other)
        return other

    # Re-assigning the DAG this task already belongs to is a no-op;
    # anything else goes through the normal dag assignment.
    already_assigned = self.has_dag() and self.dag is other
    if not already_assigned:
        self.dag = other
    return other
| |
def __rrshift__(self, other):
    """
    Reflected ``>>``: called for [DAG] >> [Operator] because DAGs don't
    have __rshift__ operators. Delegates to ``__lshift__`` and returns
    this operator.
    """
    self.__lshift__(other)
    return self
| |
def __rlshift__(self, other):
    """
    Reflected ``<<``: called for [DAG] << [Operator] because DAGs don't
    have __lshift__ operators. Delegates to ``__rshift__`` and returns
    this operator.
    """
    self.__rshift__(other)
    return self
| |
| # /Composing Operators --------------------------------------------- |
| |
@property
def dag(self):
    """
    The DAG this operator is assigned to.

    :raises AirflowException: when no DAG has been assigned yet
    """
    if not self.has_dag():
        raise AirflowException(
            'Operator {} has not been assigned to a DAG yet'.format(self))
    return self._dag
| |
@dag.setter
def dag(self, dag):
    """
    Assign the operator to a DAG. This can happen once; assigning the very
    same DAG again is accepted.

    :raises TypeError: when ``dag`` is not a DAG
    :raises AirflowException: when a different DAG is already assigned
    """
    if not isinstance(dag, DAG):
        raise TypeError(
            'Expected DAG; received {}'.format(dag.__class__.__name__))
    if self.has_dag() and self.dag is not dag:
        raise AirflowException(
            "The DAG assigned to {} can not be changed.".format(self))
    # Register with the DAG unless a task with this id is already known.
    if self.task_id not in dag.task_dict:
        dag.add_task(self)

    self._dag = dag
| |
def has_dag(self):
    """
    Tell whether a DAG has been assigned to this Operator (True/False).
    """
    assigned = getattr(self, '_dag', None)
    return assigned is not None
| |
@property
def dag_id(self):
    """
    The id of the assigned DAG, or ``'adhoc_' + owner`` when the task has
    no DAG.
    """
    return self.dag.dag_id if self.has_dag() else 'adhoc_' + self.owner
| |
@property
def deps(self):
    """
    The set of dependencies for the operator. These differ from execution
    context dependencies in that they are specific to tasks and can be
    extended/overridden by subclasses.
    """
    return set((
        NotInRetryPeriodDep(),
        PrevDagrunDep(),
        TriggerRuleDep(),
    ))
| |
@property
def schedule_interval(self):
    """
    The DAG's schedule interval when the task is attached to a DAG (the
    DAG always wins so that its tasks line up), otherwise the task's own
    ``_schedule_interval``.
    """
    if not self.has_dag():
        return self._schedule_interval
    return self.dag._schedule_interval
| |
@property
def priority_weight_total(self):
    """
    Own priority weight plus the weights of all downstream relatives.
    """
    downstream_weight = sum(
        relative.priority_weight
        for relative in self.get_flat_relatives(upstream=False))
    return downstream_weight + self.priority_weight
| |
def pre_execute(self, context):
    """
    Hook triggered right before self.execute() is called.

    Does nothing by default.
    """
| |
def execute(self, context):
    """
    The main method to derive when creating an operator.

    Context is the same dictionary used as when rendering jinja templates;
    refer to get_template_context for more context.

    :raises NotImplementedError: always, unless overridden
    """
    raise NotImplementedError()
| |
def post_execute(self, context, result=None):
    """
    Hook triggered right after self.execute() is called.

    It is passed the execution context and any results returned by the
    operator. Does nothing by default.
    """
| |
def on_kill(self):
    """
    Override this method to cleanup subprocesses when a task instance
    gets killed.

    Any use of the threading, subprocess or multiprocessing module within
    an operator needs to be cleaned up or it will leave ghost processes
    behind. Does nothing by default.
    """
| |
def __deepcopy__(self, memo):
    """
    Deep copy the task.

    ``params``, ``user_defined_macros`` and ``user_defined_filters`` are
    shared with the original instead of being copied; every other
    attribute is deep copied.

    Hack: the recursion limit is raised to avoid hitting max_depth on
    deepcopy operations of double chained task lists.
    """
    # TODO fix this in a better way
    # Only ever raise the limit: never lower one the process already
    # increased beyond 5000.
    if sys.getrecursionlimit() < 5000:
        sys.setrecursionlimit(5000)
    cls = self.__class__
    result = cls.__new__(cls)
    # Register the copy first so that cyclic references resolve to it.
    memo[id(self)] = result

    shared = ('user_defined_macros', 'user_defined_filters', 'params')
    for k, v in list(self.__dict__.items()):
        if k not in shared:
            setattr(result, k, copy.deepcopy(v, memo))
    result.params = self.params
    if hasattr(self, 'user_defined_macros'):
        result.user_defined_macros = self.user_defined_macros
    if hasattr(self, 'user_defined_filters'):
        result.user_defined_filters = self.user_defined_filters
    return result
| |
def render_template_from_field(self, attr, content, context, jinja_env):
    """
    Render the template(s) held by a field.

    A string is rendered directly. A list or tuple yields a list, and a
    dict yields a dict, whose items are rendered recursively through
    ``render_template``.

    :raises AirflowException: for any other type of content
    """
    render = self.render_template
    if isinstance(content, six.string_types):
        return jinja_env.from_string(content).render(**context)
    if isinstance(content, (list, tuple)):
        return [render(attr, item, context) for item in content]
    if isinstance(content, dict):
        return dict(
            (key, render("{}[{}]".format(attr, key), value, context))
            for key, value in list(content.items()))

    raise AirflowException((
        "Type '{param_type}' used for parameter '{attr}' is "
        "not supported for templating").format(
            param_type=type(content), attr=attr))
| |
def render_template(self, attr, content, context):
    """
    Renders a template either from a file or directly in a field, and returns
    the rendered result.

    The jinja environment comes from the DAG when the task has one,
    otherwise a bare ``jinja2.Environment`` is used.

    :param attr: name of the field being rendered
    :param content: the template: a string, or a collection handled by
        ``render_template_from_field``
    :param context: mapping of the variables exposed to the template
    """
    # ``hasattr(self, 'dag')`` is not usable here: the ``dag`` property
    # raises AirflowException (not AttributeError) when unassigned, which
    # ``hasattr`` does not swallow on Python 3.
    if self.has_dag():
        jinja_env = self.dag.get_template_env()
    else:
        jinja_env = jinja2.Environment(cache_size=0)

    exts = self.__class__.template_ext
    if (
            isinstance(content, six.string_types) and
            any(content.endswith(ext) for ext in exts)):
        # The string names a template file: load it through the environment.
        return jinja_env.get_template(content).render(**context)
    return self.render_template_from_field(attr, content, context, jinja_env)
| |
def prepare_template(self):
    """
    Hook triggered after the templated fields get replaced by their
    content.

    An operator that needs to alter the content of the file before the
    template is rendered should override this method. Does nothing by
    default.
    """
| |
def resolve_template_files(self):
    """
    Replace every templated field that names a template file (a string
    ending with one of ``template_ext``) by the source of that file, then
    call ``prepare_template``.

    Errors while loading a file are logged, not raised.
    """
    for field_name in self.template_fields:
        value = getattr(self, field_name)
        names_template_file = (
            value is not None and
            isinstance(value, six.string_types) and
            any([value.endswith(ext) for ext in self.template_ext]))
        if not names_template_file:
            continue
        env = self.dag.get_template_env()
        try:
            setattr(self, field_name, env.loader.get_source(env, value)[0])
        except Exception as e:
            self.log.exception(e)
    self.prepare_template()
| |
@property
def upstream_list(self):
    """@property: list of tasks directly upstream"""
    tasks = []
    for task_id in self._upstream_task_ids:
        # The DAG is only looked up when there is something to resolve.
        tasks.append(self.dag.get_task(task_id))
    return tasks
| |
@property
def upstream_task_ids(self):
    """@property: ids of the tasks directly upstream"""
    ids = self._upstream_task_ids
    return ids
| |
@property
def downstream_list(self):
    """@property: list of tasks directly downstream"""
    tasks = []
    for task_id in self._downstream_task_ids:
        # The DAG is only looked up when there is something to resolve.
        tasks.append(self.dag.get_task(task_id))
    return tasks
| |
@property
def downstream_task_ids(self):
    """@property: ids of the tasks directly downstream"""
    ids = self._downstream_task_ids
    return ids
| |
def clear(self, start_date=None, end_date=None, upstream=False, downstream=False):
    """
    Clears the state of task instances associated with the task, following
    the parameters specified.

    :param start_date: if set, only instances with an execution_date on
        or after this date are cleared
    :param end_date: if set, only instances with an execution_date on or
        before this date are cleared
    :param upstream: also clear the instances of all upstream relatives
    :param downstream: also clear the instances of all downstream relatives
    :return: the number of task instances matched by the query
    """
    session = settings.Session()
    try:
        TI = TaskInstance
        qry = session.query(TI).filter(TI.dag_id == self.dag_id)

        if start_date:
            qry = qry.filter(TI.execution_date >= start_date)
        if end_date:
            qry = qry.filter(TI.execution_date <= end_date)

        tasks = [self.task_id]

        if upstream:
            tasks += [
                t.task_id for t in self.get_flat_relatives(upstream=True)]

        if downstream:
            tasks += [
                t.task_id for t in self.get_flat_relatives(upstream=False)]

        qry = qry.filter(TI.task_id.in_(tasks))

        count = qry.count()

        clear_task_instances(qry.all(), session, dag=self.dag)

        session.commit()
    finally:
        # Release the connection even when the query or the clearing failed.
        session.close()
    return count
| |
def get_task_instances(self, session, start_date=None, end_date=None):
    """
    Return the task instances of this task whose execution_date lies in
    the given date range, ordered by execution_date.

    ``end_date`` defaults to the current UTC time.

    NOTE(review): with the default ``start_date=None`` the lower bound is
    compared against None; presumably callers always pass a start_date -
    confirm.
    """
    TI = TaskInstance
    if not end_date:
        end_date = datetime.utcnow()
    in_range = session.query(TI).filter(
        TI.dag_id == self.dag_id,
        TI.task_id == self.task_id,
        TI.execution_date >= start_date,
        TI.execution_date <= end_date,
    )
    return in_range.order_by(TI.execution_date).all()
| |
def get_flat_relatives(self, upstream=False, l=None):
    """
    Collect all relatives, either upstream or downstream, into one flat
    list.

    ``l`` is the accumulator used by the recursion; each relative is added
    only once.
    """
    if not l:
        l = []
    for relative in self.get_direct_relatives(upstream):
        if is_in(relative, l):
            continue
        l.append(relative)
        relative.get_flat_relatives(upstream, l)
    return l
| |
def detect_downstream_cycle(self, task=None):
    """
    Raise an exception if a cycle is detected downstream from self.

    It is invoked when tasks are added to the DAG to detect cycles.
    ``task`` is the task being searched for during the recursion and
    defaults to self.

    :return: False when no cycle was found
    :raises AirflowException: when the searched task is reached again
    """
    origin = task if task else self
    for child in self.get_direct_relatives():
        if child is origin:
            raise AirflowException(
                "Cycle detected in DAG. Faulty task: {0}".format(origin))
        child.detect_downstream_cycle(task=origin)
    return False
| |
def run(
        self,
        start_date=None,
        end_date=None,
        ignore_first_depends_on_past=False,
        ignore_ti_state=False,
        mark_success=False):
    """
    Run a task instance for every date of a date range.

    The range defaults to the task's own start_date / end_date, the end
    falling back on the current UTC time.
    """
    first = start_date or self.start_date
    last = end_date or self.end_date or datetime.utcnow()

    for execution_date in self.dag.date_range(first, end_date=last):
        ti = TaskInstance(self, execution_date)
        # depends_on_past may only be ignored for the first date of the range
        skip_past_check = (
            execution_date == first and ignore_first_depends_on_past)
        ti.run(
            mark_success=mark_success,
            ignore_depends_on_past=skip_past_check,
            ignore_ti_state=ignore_ti_state)
| |
def dry_run(self):
    """
    Log the content of every non-empty string template field, without
    executing anything.
    """
    self.log.info('Dry run')
    for field_name in self.template_fields:
        value = getattr(self, field_name)
        if not (value and isinstance(value, six.string_types)):
            continue
        self.log.info('Rendering template for %s', field_name)
        self.log.info(value)
| |
def get_direct_relatives(self, upstream=False):
    """
    Return the tasks directly upstream of the current task when
    ``upstream`` is true, the ones directly downstream otherwise.
    """
    return self.upstream_list if upstream else self.downstream_list
| |
def __repr__(self):
    """Show the operator class name and the task id."""
    template = "<Task({self.__class__.__name__}): {self.task_id}>"
    return template.format(self=self)
| |
@property
def task_type(self):
    """@property: name of the operator class"""
    cls = self.__class__
    return cls.__name__
| |
def append_only_new(self, l, item):
    """
    Append ``item`` to the list ``l`` unless that very object (identity,
    not equality) is already in it.

    :raises AirflowException: when the item is already registered
    """
    for existing in l:
        if item is existing:
            raise AirflowException(
                'Dependency {self}, {item} already registered'
                ''.format(self=self, item=item))
    l.append(item)
| |
def _set_relatives(self, task_or_task_list, upstream=False):
    """
    Register one task, or an iterable of tasks, as direct relatives of
    this task (upstream or downstream) and check the result for cycles.

    All the tasks involved must share a single DAG; the ones without a
    DAG are assigned to it.

    :raises AirflowException: when a relative is not an operator, when
        several DAGs are involved or when none of the tasks has a DAG
    """
    try:
        relatives = list(task_or_task_list)
    except TypeError:
        # not iterable: a single task was given
        relatives = [task_or_task_list]

    for candidate in relatives:
        if not isinstance(candidate, BaseOperator):
            raise AirflowException(
                "Relationships can only be set between "
                "Operators; received {}".format(candidate.__class__.__name__))

    # relationships can only be set if the tasks share a single DAG. Tasks
    # without a DAG are assigned to that DAG.
    everyone = [self] + relatives
    dags = set(t.dag for t in everyone if t.has_dag())

    if not dags:
        raise AirflowException(
            "Tried to create relationships between tasks that don't have "
            "DAGs yet. Set the DAG for at least one "
            "task and try again: {}".format(everyone))
    if len(dags) > 1:
        raise AirflowException(
            'Tried to set relationships between tasks in '
            'more than one DAG: {}'.format(dags))
    dag = next(iter(dags))

    if dag and not self.has_dag():
        self.dag = dag

    for relative in relatives:
        if dag and not relative.has_dag():
            relative.dag = dag
        if upstream:
            parent, child = relative, self
        else:
            parent, child = self, relative
        # The parent's downstream side is always recorded first.
        parent.append_only_new(parent._downstream_task_ids, child.task_id)
        child.append_only_new(child._upstream_task_ids, parent.task_id)

    self.detect_downstream_cycle()
| |
def set_downstream(self, task_or_task_list):
    """
    Set a task, or a list of tasks, to be directly downstream from the
    current task.
    """
    relate = self._set_relatives
    relate(task_or_task_list, upstream=False)
| |
def set_upstream(self, task_or_task_list):
    """
    Set a task, or a list of tasks, to be directly upstream from the
    current task.
    """
    relate = self._set_relatives
    relate(task_or_task_list, upstream=True)
| |
def xcom_push(
        self,
        context,
        key,
        value,
        execution_date=None):
    """
    Forward to the ``xcom_push`` of the task instance found under
    ``context['ti']``. See TaskInstance.xcom_push()
    """
    task_instance = context['ti']
    task_instance.xcom_push(
        key=key,
        value=value,
        execution_date=execution_date)
| |
def xcom_pull(
        self,
        context,
        task_ids,
        dag_id=None,
        key=XCOM_RETURN_KEY,
        include_prior_dates=None):
    """
    Forward to the ``xcom_pull`` of the task instance found under
    ``context['ti']`` and return its result. See TaskInstance.xcom_pull()
    """
    task_instance = context['ti']
    return task_instance.xcom_pull(
        key=key,
        task_ids=task_ids,
        dag_id=dag_id,
        include_prior_dates=include_prior_dates)
| |
| |
class DagModel(Base):
    """
    These items are stored in the database for state related information
    """

    __tablename__ = "dag"

    dag_id = Column(String(ID_LEN), primary_key=True)
    # A DAG can be paused from the UI / DB
    # Set this default value of is_paused based on a configuration value!
    is_paused_at_creation = configuration.getboolean('core',
                                                     'dags_are_paused_at_creation')
    is_paused = Column(Boolean, default=is_paused_at_creation)
    # Whether the DAG is a subdag
    is_subdag = Column(Boolean, default=False)
    # Whether that DAG was seen on the last DagBag load
    is_active = Column(Boolean, default=False)
    # Last time the scheduler started
    last_scheduler_run = Column(DateTime)
    # Last time this DAG was pickled
    last_pickled = Column(DateTime)
    # Time when the DAG last received a refresh signal
    # (e.g. the DAG's "refresh" button was clicked in the web UI)
    last_expired = Column(DateTime)
    # Whether (one of) the scheduler is scheduling this DAG at the moment
    scheduler_lock = Column(Boolean)
    # Foreign key to the latest pickle_id
    pickle_id = Column(Integer)
    # The location of the file containing the DAG object
    fileloc = Column(String(2000))
    # String representing the owners
    owners = Column(String(2000))

    def __repr__(self):
        return "<DAG: {self.dag_id}>".format(self=self)

    @classmethod
    def get_current(cls, dag_id):
        """
        Return the row stored for ``dag_id``, detached from the session,
        or None when there is no such row.

        :param dag_id: the id of the DAG to look up
        """
        session = settings.Session()
        try:
            obj = session.query(cls).filter(cls.dag_id == dag_id).first()
            # Detach the object so it stays usable once the session is closed.
            session.expunge_all()
            session.commit()
        finally:
            # Release the connection even when the lookup failed.
            session.close()
        return obj
| |
| |
| @functools.total_ordering |
| class DAG(BaseDag, LoggingMixin): |
| """ |
| A dag (directed acyclic graph) is a collection of tasks with directional |
| dependencies. A dag also has a schedule, a start and an end date |
| (optional). For each schedule, (say daily or hourly), the DAG needs to run |
| each individual tasks as their dependencies are met. Certain tasks have |
| the property of depending on their own past, meaning that they can't run |
| until their previous schedule (and upstream tasks) are completed. |
| |
| DAGs essentially act as namespaces for tasks. A task_id can only be |
| added once to a DAG. |
| |
| :param dag_id: The id of the DAG |
| :type dag_id: string |
| :param description: The description for the DAG to e.g. be shown on the webserver |
| :type description: string |
| :param schedule_interval: Defines how often that DAG runs, this |
| timedelta object gets added to your latest task instance's |
| execution_date to figure out the next schedule |
| :type schedule_interval: datetime.timedelta or |
| dateutil.relativedelta.relativedelta or str that acts as a cron |
| expression |
| :param start_date: The timestamp from which the scheduler will |
| attempt to backfill |
| :type start_date: datetime.datetime |
| :param end_date: A date beyond which your DAG won't run, leave to None |
| for open ended scheduling |
| :type end_date: datetime.datetime |
| :param template_searchpath: This list of folders (non relative) |
| defines where jinja will look for your templates. Order matters. |
| Note that jinja/airflow includes the path of your DAG file by |
| default |
| :type template_searchpath: string or list of strings |
| :param user_defined_macros: a dictionary of macros that will be exposed |
| in your jinja templates. For example, passing ``dict(foo='bar')`` |
| to this argument allows you to ``{{ foo }}`` in all jinja |
| templates related to this DAG. Note that you can pass any |
| type of object here. |
| :type user_defined_macros: dict |
| :param user_defined_filters: a dictionary of filters that will be exposed |
| in your jinja templates. For example, passing |
| ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows |
| you to ``{{ 'world' | hello }}`` in all jinja templates related to |
| this DAG. |
| :type user_defined_filters: dict |
| :param default_args: A dictionary of default parameters to be used |
| as constructor keyword parameters when initialising operators. |
| Note that operators have the same hook, and precede those defined |
| here, meaning that if your dict contains `'depends_on_past': True` |
| here and `'depends_on_past': False` in the operator's call |
| `default_args`, the actual value will be `False`. |
| :type default_args: dict |
| :param params: a dictionary of DAG level parameters that are made |
| accessible in templates, namespaced under `params`. These |
| params can be overridden at the task level. |
| :type params: dict |
| :param concurrency: the number of task instances allowed to run |
| concurrently |
| :type concurrency: int |
| :param max_active_runs: maximum number of active DAG runs, beyond this |
| number of DAG runs in a running state, the scheduler won't create |
| new active DAG runs |
| :type max_active_runs: int |
| :param dagrun_timeout: specify how long a DagRun should be up before |
| timing out / failing, so that new DagRuns can be created |
| :type dagrun_timeout: datetime.timedelta |
| :param sla_miss_callback: specify a function to call when reporting SLA |
| timeouts. |
| :type sla_miss_callback: types.FunctionType |
| :param default_view: Specify DAG default view (tree, graph, duration, gantt, landing_times) |
| :type default_view: string |
| :param orientation: Specify DAG orientation in graph view (LR, TB, RL, BT) |
| :type orientation: string |
| :param catchup: Perform scheduler catchup (or only run latest)? Defaults to True |
| :type catchup: bool |
| """ |
| |
def __init__(
        self, dag_id,
        description='',
        schedule_interval=timedelta(days=1),
        start_date=None, end_date=None,
        full_filepath=None,
        template_searchpath=None,
        user_defined_macros=None,
        user_defined_filters=None,
        default_args=None,
        concurrency=configuration.getint('core', 'dag_concurrency'),
        max_active_runs=configuration.getint(
            'core', 'max_active_runs_per_dag'),
        dagrun_timeout=None,
        sla_miss_callback=None,
        default_view=configuration.get('webserver', 'dag_default_view').lower(),
        orientation=configuration.get('webserver', 'dag_orientation'),
        catchup=configuration.getboolean('scheduler', 'catchup_by_default'),
        params=None):
    """
    Validate ``dag_id`` and store the settings of the DAG.

    The parameters are described in the class docstring. Neither
    ``default_args`` nor ``params`` is modified: a ``'params'`` entry of
    ``default_args`` is merged into the DAG's params on private copies.
    """

    self.user_defined_macros = user_defined_macros
    self.user_defined_filters = user_defined_filters
    # Work on a private (shallow) copy: the 'params' entry is removed
    # below and the caller's dict may be shared between several DAGs.
    self.default_args = dict(default_args) if default_args else {}
    self.params = params or {}

    # merging potentially conflicting default_args['params'] into params
    if 'params' in self.default_args:
        merged_params = dict(self.params)
        merged_params.update(self.default_args['params'])
        self.params = merged_params
        del self.default_args['params']

    validate_key(dag_id)

    # Properties from BaseDag
    self._dag_id = dag_id
    self._full_filepath = full_filepath if full_filepath else ''
    self._concurrency = concurrency
    self._pickle_id = None

    self._description = description
    # set file location to caller source path
    self.fileloc = sys._getframe().f_back.f_code.co_filename
    self.task_dict = dict()
    self.start_date = start_date
    self.end_date = end_date
    self.schedule_interval = schedule_interval
    # ``_schedule_interval`` holds the resolved value: presets are looked
    # up in cron_presets and '@once' means no recurring schedule.
    if schedule_interval in cron_presets:
        self._schedule_interval = cron_presets.get(schedule_interval)
    elif schedule_interval == '@once':
        self._schedule_interval = None
    else:
        self._schedule_interval = schedule_interval
    if isinstance(template_searchpath, six.string_types):
        template_searchpath = [template_searchpath]
    self.template_searchpath = template_searchpath
    self.parent_dag = None  # Gets set when DAGs are loaded
    self.last_loaded = datetime.utcnow()
    self.safe_dag_id = dag_id.replace('.', '__dot__')
    self.max_active_runs = max_active_runs
    self.dagrun_timeout = dagrun_timeout
    self.sla_miss_callback = sla_miss_callback
    self.default_view = default_view
    self.orientation = orientation
    self.catchup = catchup
    self.is_subdag = False  # DagBag.bag_dag() will set this to True if appropriate

    self.partial = False

    # Attribute names taken into account by __eq__ and __hash__.
    self._comps = {
        'dag_id',
        'task_ids',
        'parent_dag',
        'start_date',
        'schedule_interval',
        'full_filepath',
        'template_searchpath',
        'last_loaded',
    }
| |
| def __repr__(self): |
| return "<DAG: {self.dag_id}>".format(self=self) |
| |
| def __eq__(self, other): |
| return ( |
| type(self) == type(other) and |
| # Use getattr() instead of __dict__ as __dict__ doesn't return |
| # correct values for properties. |
| all(getattr(self, c, None) == getattr(other, c, None) |
| for c in self._comps)) |
| |
| def __ne__(self, other): |
| return not self == other |
| |
| def __lt__(self, other): |
| return self.dag_id < other.dag_id |
| |
| def __hash__(self): |
| hash_components = [type(self)] |
| for c in self._comps: |
| # task_ids returns a list and lists can't be hashed |
| if c == 'task_ids': |
| val = tuple(self.task_dict.keys()) |
| else: |
| val = getattr(self, c, None) |
| try: |
| hash(val) |
| hash_components.append(val) |
| except TypeError: |
| hash_components.append(repr(val)) |
| return hash(tuple(hash_components)) |
| |
| # Context Manager ----------------------------------------------- |
| |
def __enter__(self):
    """
    Make this DAG the module-level ``_CONTEXT_MANAGER_DAG``, remembering
    the previous value so that ``__exit__`` can restore it.
    """
    global _CONTEXT_MANAGER_DAG
    previous = _CONTEXT_MANAGER_DAG
    self._old_context_manager_dag = previous
    _CONTEXT_MANAGER_DAG = self
    return self
| |
def __exit__(self, _type, _value, _tb):
    """
    Restore the ``_CONTEXT_MANAGER_DAG`` value saved by ``__enter__``.
    Returns ``None``, so exceptions raised in the ``with`` body propagate.
    """
    global _CONTEXT_MANAGER_DAG
    saved = self._old_context_manager_dag
    _CONTEXT_MANAGER_DAG = saved
| |
| # /Context Manager ---------------------------------------------- |
| |
def date_range(self, start_date, num=None, end_date=None):
    """
    Delegate to ``utils_date_range`` using this DAG's resolved schedule
    interval as the step.

    :param start_date: first date of the range
    :type start_date: datetime
    :param num: when truthy, ``end_date`` is ignored (``None`` is passed on)
    :param end_date: upper bound of the range; when omitted and ``num`` is
        not given, the current UTC time at call time is used
    :type end_date: datetime
    :return: whatever ``utils_date_range`` returns
    """
    if num:
        end_date = None
    elif end_date is None:
        # Resolved per call: a default of ``datetime.utcnow()`` in the
        # signature would be frozen at the time the module was imported.
        end_date = datetime.utcnow()
    return utils_date_range(
        start_date=start_date, end_date=end_date,
        num=num, delta=self._schedule_interval)
| |
def following_schedule(self, dttm):
    """
    Return the schedule point that follows ``dttm``.

    A ``timedelta`` interval is added to ``dttm``; a string interval is
    handed to ``croniter``. Any other interval yields ``None``.
    """
    interval = self._schedule_interval
    if isinstance(interval, timedelta):
        return dttm + interval
    if isinstance(interval, six.string_types):
        return croniter(interval, dttm).get_next(datetime)
    return None
| |
def previous_schedule(self, dttm):
    """
    Return the schedule point that precedes ``dttm``.

    A ``timedelta`` interval is subtracted from ``dttm``; a string interval
    is handed to ``croniter``. Any other interval yields ``None``.
    """
    interval = self._schedule_interval
    if isinstance(interval, timedelta):
        return dttm - interval
    if isinstance(interval, six.string_types):
        return croniter(interval, dttm).get_prev(datetime)
    return None
| |
def get_run_dates(self, start_date, end_date=None):
    """
    Build the list of schedule points of this dag that fall inside the
    given interval. The dates can be used as execution dates.

    :param start_date: the start date of the interval; when falsy, the
        earliest ``start_date`` among the dag's tasks is used
    :type start_date: datetime
    :param end_date: the end date of the interval, defaults to datetime.utcnow()
    :type end_date: datetime
    :return: a list of dates within the interval following the dag's schedule
    :rtype: list
    """
    first = start_date or min([task.start_date for task in self.tasks])
    last = end_date or datetime.utcnow()

    # next run date for a subdag isn't relevant (schedule_interval for subdags
    # is ignored) so we use the dag run's start date in the case of a subdag
    if self.is_subdag:
        current = first
    else:
        current = self.normalize_schedule(first)

    dates = []
    while current and current <= last:
        dates.append(current)
        current = self.following_schedule(current)
    return dates
| |
def normalize_schedule(self, dttm):
    """
    Align ``dttm`` with the schedule.

    Returns ``dttm`` itself when there is no following schedule point or
    when stepping forward and back again lands on ``dttm``; otherwise
    returns the following schedule point.
    """
    upcoming = self.following_schedule(dttm)

    # in case of @once
    if not upcoming:
        return dttm

    off_schedule = self.previous_schedule(upcoming) != dttm
    return upcoming if off_schedule else dttm
| |
@provide_session
def get_last_dagrun(self, session=None, include_externally_triggered=False):
    """
    Return the dag run of this dag with the greatest execution date, or
    None when there is none.

    :param session: database session
    :param include_externally_triggered: when False (the default), runs
        whose ``external_trigger`` flag is set are left out
    :type include_externally_triggered: bool
    """
    query = session.query(DagRun).filter(DagRun.dag_id == self.dag_id)
    if not include_externally_triggered:
        query = query.filter(DagRun.external_trigger.__eq__(False))
    return query.order_by(DagRun.execution_date.desc()).first()
| |
@property
def dag_id(self):
    """Identifier of this DAG, read from ``_dag_id``."""
    return self._dag_id

@dag_id.setter
def dag_id(self, value):
    # Plain assignment; the new identifier is stored as given.
    self._dag_id = value
| |
@property
def full_filepath(self):
    """Path stored in ``_full_filepath`` (an empty string when none was given)."""
    return self._full_filepath

@full_filepath.setter
def full_filepath(self, value):
    # Plain assignment; the path is stored as given.
    self._full_filepath = value
| |
@property
def concurrency(self):
    """Concurrency limit of this DAG, read from ``_concurrency``."""
    return self._concurrency

@concurrency.setter
def concurrency(self, value):
    # Plain assignment; the limit is stored as given.
    self._concurrency = value
| |
@property
def description(self):
    """Read-only description of this DAG, read from ``_description``."""
    return self._description
| |
@property
def pickle_id(self):
    """Value stored in ``_pickle_id`` (``None`` until one is assigned)."""
    return self._pickle_id

@pickle_id.setter
def pickle_id(self, value):
    # Plain assignment; the id is stored as given.
    self._pickle_id = value
| |
@property
def tasks(self):
    """
    A new list holding the values of ``task_dict``. Mutating the returned
    list does not change the DAG.
    """
    registered = self.task_dict.values()
    return list(registered)

@tasks.setter
def tasks(self, val):
    # Assignment is refused on purpose; tasks are registered via add_task().
    message = 'DAG.tasks can not be modified. Use dag.add_task() instead.'
    raise AttributeError(message)
| |
@property
def task_ids(self):
    """A new list holding the keys of ``task_dict``."""
    return [task_id for task_id in self.task_dict.keys()]
| |
@property
def active_task_ids(self):
    """Keys of ``task_dict`` whose task has a falsy ``adhoc`` attribute."""
    return [
        task_id
        for task_id, task in self.task_dict.items()
        if not task.adhoc
    ]
| |
@property
def active_tasks(self):
    """Tasks of this DAG whose ``adhoc`` attribute is falsy."""
    selected = []
    for task in self.tasks:
        if task.adhoc:
            continue
        selected.append(task)
    return selected
| |
@property
def filepath(self):
    """
    ``full_filepath`` with every occurrence of the DAGS_FOLDER prefix and
    of this module's directory prefix (each followed by '/') removed.
    """
    dags_prefix = settings.DAGS_FOLDER + '/'
    module_prefix = os.path.dirname(__file__) + '/'
    shortened = self.full_filepath.replace(dags_prefix, '')
    return shortened.replace(module_prefix, '')
| |
@property
def folder(self):
    """Directory part of ``full_filepath``."""
    directory, _filename = os.path.split(self.full_filepath)
    return directory
| |
@property
def owner(self):
    """Distinct ``owner`` values of the DAG's tasks, joined with ', '."""
    distinct_owners = set([task.owner for task in self.tasks])
    separator = ", "
    return separator.join(list(distinct_owners))
| |
@property
@provide_session
def concurrency_reached(self, session=None):
    """
    True when the number of RUNNING task instances belonging to this DAG's
    tasks is greater than or equal to ``self.concurrency``.
    """
    running_count = (
        session.query(func.count(TaskInstance.task_id))
        .filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id.in_(self.task_ids),
            TaskInstance.state == State.RUNNING)
        .scalar())
    return running_count >= self.concurrency
| |
@property
@provide_session
def is_paused(self, session=None):
    """
    Value of the ``is_paused`` column of the ``DagModel`` row whose dag_id
    matches this DAG.
    """
    matching = session.query(DagModel).filter(DagModel.dag_id == self.dag_id)
    return matching.value('is_paused')
| |
@provide_session
def get_active_runs(self, session=None):
    """
    Returns the execution dates of this dag's runs that are in the
    RUNNING state, as found by ``DagRun.find``.

    :param session: accepted for ``provide_session``; it is not forwarded
        to ``DagRun.find``
    :return: List of execution dates
    """
    running = DagRun.find(dag_id=self.dag_id, state=State.RUNNING)
    return [dag_run.execution_date for dag_run in running]
| |
@provide_session
def get_num_active_runs(self, external_trigger=None, session=None):
    """
    Count this dag's runs that are in the RUNNING state.

    :param external_trigger: when not None, only runs whose
        ``external_trigger`` flag equals this value are counted
    :type external_trigger: bool
    :param session: database session
    :return: the number of matching dag runs
    """
    running = session.query(DagRun).filter(
        DagRun.dag_id == self.dag_id,
        DagRun.state == State.RUNNING)

    if external_trigger is None:
        return running.count()
    return running.filter(
        DagRun.external_trigger == external_trigger).count()
| |
@provide_session
def get_dagrun(self, execution_date, session=None):
    """
    Look up the run of this dag that has the given execution date.

    :param execution_date: The execution date of the DagRun to find.
    :param session: database session
    :return: The first matching DagRun, or None when there is none.
    """
    matching = session.query(DagRun).filter(
        DagRun.dag_id == self.dag_id,
        DagRun.execution_date == execution_date)
    return matching.first()
| |
@property
def latest_execution_date(self):
    """
    Returns the latest date for which at least one dag run exists, or
    None when the query finds no dag run for this dag.

    A dedicated session is opened for the lookup and is always closed,
    including when the query raises.
    """
    session = settings.Session()
    try:
        execution_date = session.query(func.max(DagRun.execution_date)).filter(
            DagRun.dag_id == self.dag_id
        ).scalar()
        session.commit()
    finally:
        # Without this the session would stay open whenever the query fails
        session.close()
    return execution_date
| |
@property
def subdags(self):
    """
    Returns a list of the subdag objects reachable from this DAG: the
    ``subdag`` of every SubDagOperator task, followed by that subdag's own
    subdags.
    """
    # Check SubDag for class but don't check class directly, see
    # https://github.com/airbnb/airflow/issues/1168
    from airflow.operators.subdag_operator import SubDagOperator

    collected = []
    for task in self.tasks:
        is_subdag_task = (
            isinstance(task, SubDagOperator) or
            # TODO remove in Airflow 2.0
            type(task).__name__ == 'SubDagOperator')
        if not is_subdag_task:
            continue
        child = task.subdag
        collected.append(child)
        collected.extend(child.subdags)
    return collected
| |
def resolve_template_files(self):
    """Call ``resolve_template_files()`` on every task of this DAG, in order."""
    for task in self.tasks:
        task.resolve_template_files()
| |
def get_template_env(self):
    """
    Build a jinja2 Environment for this DAG.

    Templates are looked up in the DAG's folder first, then in
    ``template_searchpath``; ``user_defined_macros`` are added to the
    environment globals and ``user_defined_filters`` to its filters.
    """
    search_dirs = [self.folder]
    if self.template_searchpath:
        search_dirs.extend(self.template_searchpath)

    loader = jinja2.FileSystemLoader(search_dirs)
    environment = jinja2.Environment(
        loader=loader,
        extensions=["jinja2.ext.do"],
        cache_size=0)

    macros = self.user_defined_macros
    if macros:
        environment.globals.update(macros)
    filters = self.user_defined_filters
    if filters:
        environment.filters.update(filters)

    return environment
| |
def set_dependency(self, upstream_task_id, downstream_task_id):
    """
    Link two tasks that are already part of the DAG: the task looked up
    by ``downstream_task_id`` becomes a downstream of the task looked up
    by ``upstream_task_id``.
    """
    upstream_task = self.get_task(upstream_task_id)
    downstream_task = self.get_task(downstream_task_id)
    upstream_task.set_downstream(downstream_task)
| |
def get_task_instances(
        self, session, start_date=None, end_date=None, state=None):
    """
    Query the task instances of this DAG's tasks whose execution date lies
    in ``[start_date, end_date]``, ordered by execution date.

    :param session: database session used for the query
    :param start_date: lower bound; when falsy, midnight of the day
        30 days before the current UTC time is used
    :param end_date: upper bound; when falsy, the current UTC time is used
    :param state: when truthy, only instances in this state are returned
    :return: list of task instances
    """
    if not start_date:
        thirty_days_ago = (datetime.utcnow() - timedelta(30)).date()
        start_date = datetime.combine(thirty_days_ago, datetime.min.time())
    end_date = end_date or datetime.utcnow()

    own_task_ids = [task.task_id for task in self.tasks]
    query = session.query(TaskInstance).filter(
        TaskInstance.dag_id == self.dag_id,
        TaskInstance.execution_date >= start_date,
        TaskInstance.execution_date <= end_date,
        TaskInstance.task_id.in_(own_task_ids),
    )
    if state:
        query = query.filter(TaskInstance.state == state)
    return query.order_by(TaskInstance.execution_date).all()
| |
@property
def roots(self):
    """Tasks of this DAG whose ``downstream_list`` is empty."""
    without_downstream = []
    for task in self.tasks:
        if task.downstream_list:
            continue
        without_downstream.append(task)
    return without_downstream
| |
def topological_sort(self):
    """
    Sorts tasks in topological order, such that a task comes after any of its
    upstream dependencies.

    Heavily inspired by:
    http://blog.jupo.org/2012/04/06/topological-sorting-acyclic-directed-graphs/

    :return: tuple of tasks in topological order
    :raises AirflowException: when a full pass resolves no task, which
        means the remaining tasks depend on each other in a cycle
    """
    # work on a copy of the task list so the DAG itself is left unmodified
    pending = self.tasks[:]
    ordered = []

    # special case: nothing to sort
    if not pending:
        return tuple(ordered)

    while pending:
        # Each pass walks a snapshot of the pending tasks, so tasks can be
        # removed from ``pending`` while iterating. A task is resolved as
        # soon as none of its upstream tasks is still pending.
        progressed = False
        for node in list(pending):
            blocked = any(
                upstream in pending for upstream in node.upstream_list)
            if blocked:
                continue
            progressed = True
            pending.remove(node)
            ordered.append(node)

        # A pass without progress can never finish: the graph has a cycle
        if not progressed:
            raise AirflowException("A cyclic dependency occurred in dag: {}"
                                   .format(self.dag_id))

    return tuple(ordered)
| |
@provide_session
def set_dag_runs_state(
        self, state=State.RUNNING, session=None):
    """
    Set the state of every dag run of this DAG and refresh the DAG
    statistics afterwards.

    :param state: the state assigned to each dag run
    :param session: database session
    """
    # The runs live in DagRun; querying DagModel here (as was done before)
    # only set an unmapped attribute on the DAG row and changed no run.
    drs = session.query(DagRun).filter_by(dag_id=self.dag_id).all()
    dirty_ids = []
    for dr in drs:
        dr.state = state
        dirty_ids.append(dr.dag_id)
    DagStat.update(dirty_ids, session=session)
| |
def clear(
        self, start_date=None, end_date=None,
        only_failed=False,
        only_running=False,
        confirm_prompt=False,
        include_subdags=True,
        reset_dag_runs=True,
        dry_run=False):
    """
    Clears a set of task instances associated with the current dag for
    a specified date range.

    :param start_date: when truthy, only execution dates >= this are matched
    :param end_date: when truthy, only execution dates <= this are matched
    :param only_failed: restrict to task instances in the FAILED state
    :param only_running: restrict to task instances in the RUNNING state
    :param confirm_prompt: ask for confirmation before clearing
    :param include_subdags: also match the task instances of ``self.subdags``
    :param reset_dag_runs: call ``set_dag_runs_state`` after clearing
    :param dry_run: do not clear anything, return the matching task
        instances (detached from the session) instead
    :return: the list of matching task instances when ``dry_run`` is set,
        otherwise the number of task instances cleared (0 when nothing
        matched or the prompt was declined)
    """
    TI = TaskInstance
    session = settings.Session()
    # The session is opened here, so it is closed here on every exit path
    # (early returns and exceptions included).
    try:
        if include_subdags:
            # Crafting the right filter for dag_id and task_ids combo.
            # Exact match on dag_id: LIKE would treat '_' and '%' inside a
            # dag id as wildcards.
            conditions = [
                (TI.dag_id == dag.dag_id) & TI.task_id.in_(dag.task_ids)
                for dag in self.subdags + [self]
            ]
            tis = session.query(TI).filter(or_(*conditions))
        else:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag_id,
                TI.task_id.in_(self.task_ids))

        if start_date:
            tis = tis.filter(TI.execution_date >= start_date)
        if end_date:
            tis = tis.filter(TI.execution_date <= end_date)
        if only_failed:
            tis = tis.filter(TI.state == State.FAILED)
        if only_running:
            tis = tis.filter(TI.state == State.RUNNING)

        if dry_run:
            matched = tis.all()
            # Detach the instances so they stay usable after the session
            # is closed
            session.expunge_all()
            return matched

        count = tis.count()
        if count == 0:
            return 0

        do_it = True
        if confirm_prompt:
            ti_list = "\n".join([str(t) for t in tis])
            question = (
                "You are about to delete these {count} tasks:\n"
                "{ti_list}\n\n"
                "Are you sure? (yes/no): ").format(
                    count=count, ti_list=ti_list)
            do_it = utils.helpers.ask_yesno(question)

        if do_it:
            clear_task_instances(tis.all(), session, dag=self)
            if reset_dag_runs:
                self.set_dag_runs_state(session=session)
        else:
            count = 0
            print("Bail. Nothing was cleared.")

        session.commit()
        return count
    finally:
        session.close()
| |
@classmethod
def clear_dags(
        cls, dags,
        start_date=None,
        end_date=None,
        only_failed=False,
        only_running=False,
        confirm_prompt=False,
        include_subdags=True,
        reset_dag_runs=True,
        dry_run=False):
    """
    Clear task instances across several dags with a single confirmation.

    Every dag is first asked for its matching task instances through a
    dry run of ``dag.clear``; the real clearing only happens afterwards.

    :param dags: the dags to clear; iterated twice when not a dry run
    :param confirm_prompt: ask once, for all dags, before clearing
    :param dry_run: return the collected task instances without clearing
    :return: the collected task instances when ``dry_run`` is set,
        otherwise the number of task instances found by the dry run
        (0 when nothing matched or the prompt was declined)

    The remaining parameters are passed to each ``dag.clear`` call as is.
    """
    # Arguments shared by the dry run and the real run of dag.clear();
    # the per-dag prompt is always disabled in favour of the one below.
    shared = dict(
        start_date=start_date,
        end_date=end_date,
        only_failed=only_failed,
        only_running=only_running,
        confirm_prompt=False,
        include_subdags=include_subdags,
        reset_dag_runs=reset_dag_runs)

    all_tis = []
    for dag in dags:
        all_tis.extend(dag.clear(dry_run=True, **shared))

    if dry_run:
        return all_tis

    count = len(all_tis)
    if count == 0:
        print("Nothing to clear.")
        return 0

    proceed = True
    if confirm_prompt:
        ti_list = "\n".join([str(t) for t in all_tis])
        question = (
            "You are about to delete these {} tasks:\n"
            "{}\n\n"
            "Are you sure? (yes/no): ").format(count, ti_list)
        proceed = utils.helpers.ask_yesno(question)

    if not proceed:
        print("Bail. Nothing was cleared.")
        return 0

    for dag in dags:
        dag.clear(dry_run=False, **shared)
    return count
| |
| def __deepcopy__(self, memo): |
| # Swiwtcharoo to go around deepcopying objects coming through the |
| # backdoor |
| cls = self.__class__ |
| result = cls.__new__(cls) |
| memo[id(self)] = result |
| for k, v in list(self.__dict__.items()): |
| if k not in ('user_defined_macros', 'user_defined_filters', 'params'): |
| setattr(result, k, copy.deepcopy(v, memo)) |
| |
| result.user_defined_macros = self.user_defined_macros |
| result.user_defined_filters = self.user_defined_filters |
| result.params = self.params |
| return result |
| |
def sub_dag(self, task_regex, include_downstream=False,
            include_upstream=True):
    """
    Returns a subset of the current dag as a deep copy of the current dag
    based on a regex that should match one or many tasks, and includes
    upstream and downstream neighbours based on the flag passed.

    :param task_regex: pattern searched for (``re.findall``) in each task_id
    :param include_downstream: also keep the flat downstream relatives of
        every matching task
    :param include_upstream: also keep the flat upstream relatives of
        every matching task
    :return: the copied dag; its ``partial`` flag is set to True when it
        holds fewer tasks than the original
    """

    dag = copy.deepcopy(self)

    regex_match = [
        t for t in dag.tasks if re.findall(task_regex, t.task_id)]
    also_include = []
    for t in regex_match:
        if include_downstream:
            also_include += t.get_flat_relatives(upstream=False)
        if include_upstream:
            also_include += t.get_flat_relatives(upstream=True)

    # Compiling the unique list of tasks that made the cut
    dag.task_dict = {t.task_id: t for t in regex_match + also_include}

    # Computed once: looking ids up through dag.task_ids would rebuild a
    # list for every single membership test.
    kept_ids = set(dag.task_dict)
    for t in dag.tasks:
        # Removing upstream/downstream references to tasks that did not
        # make the cut
        t._upstream_task_ids = [
            tid for tid in t._upstream_task_ids if tid in kept_ids]
        t._downstream_task_ids = [
            tid for tid in t._downstream_task_ids if tid in kept_ids]

    if len(dag.tasks) < len(self.tasks):
        dag.partial = True

    return dag
| |
def has_task(self, task_id):
    """Tell whether any task of this DAG has the given ``task_id``."""
    for task in self.tasks:
        known_id = task.task_id
        if known_id is task_id or known_id == task_id:
            return True
    return False
| |
def get_task(self, task_id):
    """
    Return the task registered under ``task_id`` in ``task_dict``.

    :raises AirflowException: when no task has that id
    """
    if task_id not in self.task_dict:
        raise AirflowException(
            "Task {task_id} not found".format(task_id=task_id))
    return self.task_dict[task_id]
| |
@provide_session
def pickle_info(self, session=None):
    """
    Try to pickle this DAG and report the outcome.

    :param session: accepted for ``provide_session``; not used here
    :return: dict with ``is_picklable`` plus, on success, ``pickle_len``
        and ``pickling_duration``, or, on failure, ``stacktrace``
    """
    info = {'is_picklable': True}
    try:
        started_at = datetime.utcnow()
        payload = pickle.dumps(self)
        info['pickle_len'] = len(payload)
        info['pickling_duration'] = "{}".format(
            datetime.utcnow() - started_at)
    except Exception as e:
        # Any failure simply means "not picklable"; keep the trace for display
        self.log.debug(e)
        info['is_picklable'] = False
        info['stacktrace'] = traceback.format_exc()
    return info
| |
@provide_session
def pickle(self, session=None):
    """
    Make sure a ``DagPickle`` row matching this DAG exists and return it.

    The pickle referenced by this DAG's ``DagModel`` row is reused when
    its unpickled content equals ``self``; otherwise a new ``DagPickle``
    is stored, committed, and its id is assigned to ``self.pickle_id``.

    :param session: database session
    :return: the reused or newly created ``DagPickle``
    """
    orm_dag = session.query(DagModel).filter(
        DagModel.dag_id == self.dag_id).first()

    stored = None
    if orm_dag and orm_dag.pickle_id:
        stored = session.query(DagPickle).filter(
            DagPickle.id == orm_dag.pickle_id).first()

    needs_new_pickle = not stored or stored.pickle != self
    if needs_new_pickle:
        stored = DagPickle(dag=self)
        session.add(stored)
        self.last_pickled = datetime.utcnow()
        session.commit()
        self.pickle_id = stored.id

    return stored
| |
def tree_view(self):
    """
    Prints an ascii tree representation of the DAG: each task in
    ``self.roots`` is printed, followed by its upstream tasks indented by
    four more spaces per level.
    """
    def print_with_upstream(task, depth=0):
        indent = " " * depth * 4
        print(indent + str(task))
        for upstream_task in task.upstream_list:
            print_with_upstream(upstream_task, depth + 1)

    for root in self.roots:
        print_with_upstream(root)
| |
def add_task(self, task):
    """
    Add a task to the DAG

    The task's ``start_date`` and ``end_date`` are first reconciled with
    the DAG's (also when the task ends up not being added). A task whose
    ``task_id`` is already registered is not added; a
    ``PendingDeprecationWarning`` is emitted instead.

    :param task: the task you want to add
    :type task: task
    :raises AirflowException: when neither the DAG nor the task has a
        start_date
    """
    if not self.start_date and not task.start_date:
        raise AirflowException("Task is missing the start_date parameter")
    # if the task has no start date, assign it the same as the DAG
    elif not task.start_date:
        task.start_date = self.start_date
    # otherwise, the task will start on the later of its own start date and
    # the DAG's start date
    elif self.start_date:
        task.start_date = max(task.start_date, self.start_date)

    # if the task has no end date, assign it the same as the dag
    if not task.end_date:
        task.end_date = self.end_date
    # otherwise, the task will end on the earlier of its own end date and
    # the DAG's end date
    elif self.end_date:
        task.end_date = min(task.end_date, self.end_date)

    if task.task_id in self.task_dict:
        # TODO: raise an error in Airflow 2.0
        warnings.warn(
            'The requested task could not be added to the DAG because a '
            'task with task_id {} is already in the DAG. Starting in '
            'Airflow 2.0, trying to overwrite a task will raise an '
            'exception.'.format(task.task_id),
            category=PendingDeprecationWarning)
    else:
        # task_dict is the only registry; ``self.tasks`` is a fresh list
        # built from it, so appending to that list would have no effect.
        self.task_dict[task.task_id] = task
        task.dag = self

    self.task_count = len(self.task_dict)
| |
def add_tasks(self, tasks):
    """
    Add several tasks to the DAG by calling ``add_task`` on each of them
    in order

    :param tasks: a list of tasks you want to add
    :type tasks: list of tasks
    """
    register = self.add_task
    for new_task in tasks:
        register(new_task)
| |
def db_merge(self):
    """
    Delete the stored ``BaseOperator`` rows whose dag_id matches this DAG,
    commit, then merge this DAG into the session and commit again.

    The session opened here is always closed, including on failure.
    """
    BO = BaseOperator
    session = settings.Session()
    try:
        tasks = session.query(BO).filter(BO.dag_id == self.dag_id).all()
        for t in tasks:
            session.delete(t)
        session.commit()
        session.merge(self)
        session.commit()
    finally:
        # Previously the session was never closed
        session.close()
| |
def run(
        self,
        start_date=None,
        end_date=None,
        mark_success=False,
        include_adhoc=False,
        local=False,
        executor=None,
        donot_pickle=configuration.getboolean('core', 'donot_pickle'),
        ignore_task_deps=False,
        ignore_first_depends_on_past=False,
        pool=None,
        delay_on_limit_secs=1.0):
    """
    Runs the DAG by building a ``BackfillJob`` for it and running that job.

    :param start_date: the start date of the range to run
    :type start_date: datetime
    :param end_date: the end date of the range to run
    :type end_date: datetime
    :param mark_success: True to mark jobs as succeeded without running them
    :type mark_success: bool
    :param include_adhoc: True to include dags with the adhoc parameter
    :type include_adhoc: bool
    :param local: True to run the tasks using the LocalExecutor; only
        considered when no executor is given
    :type local: bool
    :param executor: The executor instance to run the tasks; when falsy,
        a LocalExecutor (if ``local``) or the default executor is used
    :type executor: BaseExecutor
    :param donot_pickle: True to avoid pickling DAG object and send to workers
    :type donot_pickle: bool
    :param ignore_task_deps: True to skip upstream tasks
    :type ignore_task_deps: bool
    :param ignore_first_depends_on_past: True to ignore depends_on_past
        dependencies for the first set of tasks only
    :type ignore_first_depends_on_past: bool
    :param pool: Resource pool to use
    :type pool: string
    :param delay_on_limit_secs: Time in seconds to wait before next attempt to run
        dag run when max_active_runs limit has been reached
    :type delay_on_limit_secs: float
    """
    # Imported here rather than at module level
    from airflow.jobs import BackfillJob

    if not executor:
        executor = LocalExecutor() if local else GetDefaultExecutor()

    job_options = dict(
        start_date=start_date,
        end_date=end_date,
        mark_success=mark_success,
        include_adhoc=include_adhoc,
        executor=executor,
        donot_pickle=donot_pickle,
        ignore_task_deps=ignore_task_deps,
        ignore_first_depends_on_past=ignore_first_depends_on_past,
        pool=pool,
        delay_on_limit_secs=delay_on_limit_secs)
    backfill = BackfillJob(self, **job_options)
    backfill.run()
| |
def cli(self):
    """
    Exposes a CLI specific to this DAG: the process arguments are parsed
    with the DAG flavour of the Airflow parser and the selected command is
    invoked with this DAG.
    """
    from airflow.bin import cli as airflow_cli

    dag_parser = airflow_cli.CLIFactory.get_parser(dag_parser=True)
    parsed = dag_parser.parse_args()
    parsed.func(parsed, self)
| |
@provide_session
def create_dagrun(self,
                  run_id,
                  state,
                  execution_date=None,
                  start_date=None,
                  external_trigger=False,
                  conf=None,
                  session=None):
    """
    Creates a dag run from this dag including the tasks associated with this dag.
    Returns the dag run.

    :param run_id: defines the run id for this dag run
    :type run_id: string
    :param state: the state of the dag run
    :type state: State
    :param execution_date: the execution date of this dag run
    :type execution_date: datetime
    :param start_date: the date this dag run should be evaluated
    :type start_date: datetime
    :param external_trigger: whether this dag run is externally triggered
    :type external_trigger: bool
    :param conf: passed to the ``DagRun`` constructor unchanged
    :param session: database session
    :type session: Session
    """
    run_fields = dict(
        dag_id=self.dag_id,
        run_id=run_id,
        execution_date=execution_date,
        start_date=start_date,
        external_trigger=external_trigger,
        conf=conf,
        state=state)
    run = DagRun(**run_fields)
    session.add(run)

    # The statistics of this dag are flagged dirty before committing
    DagStat.set_dirty(dag_id=self.dag_id, session=session)
    session.commit()

    run.dag = self

    # create the associated task instances
    # state is None at the moment of creation
    run.verify_integrity(session=session)
    run.refresh_from_db()

    return run
| |
@provide_session
def sync_to_db(self, owner=None, sync_time=None, session=None):
    """
    Save attributes about this DAG to the DB, creating its ``DagModel``
    row when none exists yet, then do the same for every dag in
    ``self.subdags`` using the same owner and sync time.

    :param owner: value stored in the ``owners`` column; defaults to
        ``self.owner``
    :param sync_time: The time that the DAG should be marked as sync'ed;
        defaults to the current UTC time
    :type sync_time: datetime
    :param session: database session
    :return: None
    """
    owner = self.owner if owner is None else owner
    sync_time = datetime.utcnow() if sync_time is None else sync_time

    orm_dag = session.query(DagModel).filter(
        DagModel.dag_id == self.dag_id).first()
    if not orm_dag:
        orm_dag = DagModel(dag_id=self.dag_id)
        self.log.info("Creating ORM DAG for %s", self.dag_id)

    orm_dag.fileloc = self.fileloc
    orm_dag.is_subdag = self.is_subdag
    orm_dag.owners = owner
    orm_dag.is_active = True
    orm_dag.last_scheduler_run = sync_time
    session.merge(orm_dag)
    session.commit()

    for child in self.subdags:
        child.sync_to_db(owner=owner, sync_time=sync_time, session=session)
| |
@staticmethod
@provide_session
def deactivate_unknown_dags(active_dag_ids, session=None):
    """
    Set ``is_active`` to False on every ``DagModel`` row whose dag_id is
    not in the given list. Nothing is done when the list is empty.

    :param active_dag_ids: list of DAG IDs that are active
    :type active_dag_ids: list[unicode]
    :param session: database session
    :return: None
    """
    if len(active_dag_ids) == 0:
        return

    unknown_dags = session.query(DagModel).filter(
        ~DagModel.dag_id.in_(active_dag_ids)).all()
    for orm_dag in unknown_dags:
        orm_dag.is_active = False
        session.merge(orm_dag)
| |
@staticmethod
@provide_session
def deactivate_stale_dags(expiration_date, session=None):
    """
    Deactivate any active DAGs that were last touched by the scheduler
    before the expiration date, logging each one, then commit.

    :param expiration_date: set inactive DAGs that were touched before this
        time
    :type expiration_date: datetime
    :param session: database session
    :return: None
    """
    log = LoggingMixin().log
    stale_dags = session.query(DagModel).filter(
        DagModel.last_scheduler_run < expiration_date,
        DagModel.is_active).all()

    for orm_dag in stale_dags:
        log.info(
            "Deactivating DAG ID %s since it was last touched by the scheduler at %s",
            orm_dag.dag_id, orm_dag.last_scheduler_run.isoformat()
        )
        orm_dag.is_active = False
        session.merge(orm_dag)
        session.commit()
| |
@staticmethod
@provide_session
def get_num_task_instances(dag_id, task_ids, states=None, session=None):
    """
    Count the task instances of the given DAG that belong to the given
    tasks, optionally restricted to some states.

    :param dag_id: ID of the DAG whose task instances are counted
    :type dag_id: unicode
    :param task_ids: A list of valid task IDs for the given DAG
    :type task_ids: list[unicode]
    :param states: A list of states to filter by if supplied; when it
        contains None, task instances without a state are counted too
    :type states: list[state]
    :param session: ORM session
    :return: The number of matching task instances
    :rtype: int
    """
    TI = TaskInstance
    counter = session.query(func.count(TI.task_id)).filter(
        TI.dag_id == dag_id,
        TI.task_id.in_(task_ids))

    if states is None:
        return counter.scalar()

    state_filter = TI.state.in_(states)
    if None in states:
        # A plain IN() does not match NULL, so it is tested separately
        state_filter = or_(state_filter, TI.state.is_(None))
    return counter.filter(state_filter).scalar()
| |
| |
class Chart(Base):
    # ORM model mapped to the "chart" table. Column semantics beyond what
    # the declarations show are not visible here; presumably each row is a
    # saved chart definition -- confirm against the views that use it.
    __tablename__ = "chart"

    id = Column(Integer, primary_key=True)
    label = Column(String(200))
    # Required connection id
    conn_id = Column(String(ID_LEN), nullable=False)
    # Optional reference to a row of the "users" table
    user_id = Column(Integer(), ForeignKey('users.id'), nullable=True)
    chart_type = Column(String(100), default="line")
    sql_layout = Column(String(50), default="series")
    sql = Column(Text, default="SELECT series, x, y FROM table")
    y_log_scale = Column(Boolean)
    show_datatable = Column(Boolean)
    show_sql = Column(Boolean, default=True)
    height = Column(Integer, default=600)
    # Text column defaulting to "{}" -- presumably JSON; TODO confirm
    default_params = Column(String(5000), default="{}")
    # Related "User" object; adds a ``charts`` backref on that model and
    # disables cascading
    owner = relationship(
        "User", cascade=False, cascade_backrefs=False, backref='charts')
    x_is_date = Column(Boolean, default=True)
    iteration_no = Column(Integer, default=0)
    last_modified = Column(DateTime, default=func.now())

    def __repr__(self):
        # The label doubles as the textual representation.
        # NOTE(review): ``label`` has no NOT NULL constraint; repr() raises
        # TypeError when this returns None -- confirm whether that can occur.
        return self.label
| |
| |
class KnownEventType(Base):
    # ORM model mapped to the "known_event_type" table.
    __tablename__ = "known_event_type"

    id = Column(Integer, primary_key=True)
    # NOTE(review): the column is spelled ``know_event_type`` (no "n");
    # renaming it would change the schema, so it is left as is.
    know_event_type = Column(String(200))

    def __repr__(self):
        # The type name doubles as the textual representation
        return self.know_event_type
| |
| |
class KnownEvent(Base):
    # ORM model mapped to the "known_event" table.
    __tablename__ = "known_event"

    id = Column(Integer, primary_key=True)
    label = Column(String(200))
    start_date = Column(DateTime)
    end_date = Column(DateTime)
    # Reference to a row of the "users" table
    user_id = Column(Integer(), ForeignKey('users.id'),)
    # Reference to a row of the "known_event_type" table
    known_event_type_id = Column(Integer(), ForeignKey('known_event_type.id'),)
    # Related "User" object; adds a ``known_events`` backref on that model
    reported_by = relationship(
        "User", cascade=False, cascade_backrefs=False, backref='known_events')
    # Related "KnownEventType" object; adds a ``known_events`` backref on
    # that model
    event_type = relationship(
        "KnownEventType",
        cascade=False,
        cascade_backrefs=False, backref='known_events')
    description = Column(Text)

    def __repr__(self):
        # The label doubles as the textual representation
        return self.label
| |
| |
class Variable(Base, LoggingMixin):
    """
    Key/value pair stored in the "variable" table.

    The value lives in the ``val`` column (attribute ``_val``) and is
    encrypted with Fernet when a Fernet instance can be obtained;
    ``is_encrypted`` records which form is stored. The ``val`` attribute
    reads and writes through ``get_val`` / ``set_val``.
    """
    __tablename__ = "variable"

    id = Column(Integer, primary_key=True)
    key = Column(String(ID_LEN), unique=True)
    _val = Column('val', Text)
    is_encrypted = Column(Boolean, unique=False, default=False)

    def __repr__(self):
        # Shows the stored column content: the Fernet token for encrypted
        # rows, but the plain value for rows stored without encryption.
        return '{} : {}'.format(self.key, self._val)

    def get_val(self):
        """
        Return the value, decrypting it when the row is marked encrypted.

        :raises AirflowException: when Fernet cannot be obtained or the
            stored value cannot be decrypted
        """
        if self._val and self.is_encrypted:
            try:
                fernet = get_fernet()
            except Exception:
                raise AirflowException(
                    "Can't decrypt _val for key={}, FERNET_KEY configuration "
                    "missing".format(self.key))
            try:
                return fernet.decrypt(bytes(self._val, 'utf-8')).decode()
            except Exception:
                raise AirflowException(
                    "Can't decrypt _val for key={}, invalid token or value"
                    .format(self.key))
        else:
            return self._val

    def set_val(self, value):
        """
        Store ``value``, encrypted when Fernet is available and in plain
        form otherwise. Falsy values are ignored: nothing is stored.
        """
        if value:
            try:
                fernet = get_fernet()
                self._val = fernet.encrypt(bytes(value, 'utf-8')).decode()
                self.is_encrypted = True
            except AirflowException:
                self.log.exception(
                    "Failed to load fernet while encrypting value, using non-encrypted value."
                )
                self._val = value
                self.is_encrypted = False

    @declared_attr
    def val(cls):
        # Public ``val`` attribute: a synonym of ``_val`` whose access goes
        # through get_val / set_val
        return synonym('_val',
                       descriptor=property(cls.get_val, cls.set_val))

    @classmethod
    def setdefault(cls, key, default, deserialize_json=False):
        """
        Like a Python builtin dict object, setdefault returns the current value
        for a key, and if it isn't there, stores the default value and returns it.

        :param key: Dict key for this Variable
        :type key: String
        :param default: Default value to set and return if the variable
            isn't already in the DB
        :type default: Mixed
        :param deserialize_json: Store this as a JSON encoded value in the DB
            and un-encode it when retrieving a value
        :return: Mixed
        :raises ValueError: when the variable does not exist and ``default``
            is None
        """
        # A private sentinel tells "not found" apart from any stored value
        default_sentinel = object()
        obj = Variable.get(key, default_var=default_sentinel, deserialize_json=deserialize_json)
        if obj is default_sentinel:
            if default is not None:
                Variable.set(key, default, serialize_json=deserialize_json)
                return default
            else:
                raise ValueError('Default Value must be set')
        else:
            return obj

    @classmethod
    @provide_session
    def get(cls, key, default_var=None, deserialize_json=False, session=None):
        """
        Return the value stored under ``key``.

        :param key: key of the variable
        :param default_var: returned when the variable does not exist;
            ``None`` means "no default"
        :param deserialize_json: decode the stored value with ``json.loads``
        :param session: database session
        :raises KeyError: when the variable does not exist and no default
            was given
        """
        obj = session.query(cls).filter(cls.key == key).first()
        if obj is None:
            if default_var is not None:
                return default_var
            else:
                raise KeyError('Variable {} does not exist'.format(key))
        else:
            if deserialize_json:
                return json.loads(obj.val)
            else:
                return obj.val

    @classmethod
    @provide_session
    def set(cls, key, value, serialize_json=False, session=None):
        """
        Replace the variable stored under ``key`` with ``value``.

        :param key: key of the variable
        :param value: value to store; JSON-encoded when ``serialize_json``
            is set, otherwise converted to text when it is not a string
        :param serialize_json: encode ``value`` with ``json.dumps``
        :param session: database session
        """
        if serialize_json:
            stored_value = json.dumps(value)
        elif value is None or isinstance(value, six.string_types):
            stored_value = value
        else:
            # set_val() encodes the value as text before encrypting it, so
            # non-string values (ints, floats, ...) are converted first
            # instead of failing there.
            stored_value = str(value)

        session.query(cls).filter(cls.key == key).delete()
        session.add(Variable(key=key, val=stored_value))
        session.flush()
| |
| |
class XCom(Base, LoggingMixin):
    """
    Base class for XCom objects.

    Each row stores a serialized ``value`` under a ``key`` together with the
    ``dag_id``, ``task_id`` and ``execution_date`` it was stored for. Values
    are serialized with pickle or JSON depending on ``enable_pickling`` /
    the ``core.enable_xcom_pickling`` configuration option.
    """
    __tablename__ = "xcom"

    id = Column(Integer, primary_key=True)
    key = Column(String(512))
    value = Column(LargeBinary)
    timestamp = Column(
        DateTime, default=func.now(), nullable=False)
    execution_date = Column(DateTime, nullable=False)

    # source information
    task_id = Column(String(ID_LEN), nullable=False)
    dag_id = Column(String(ID_LEN), nullable=False)

    __table_args__ = (
        Index('idx_xcom_dag_task_date', dag_id, task_id, execution_date, unique=False),
    )

    def __repr__(self):
        return '<XCom "{key}" ({task_id} @ {execution_date})>'.format(
            key=self.key,
            task_id=self.task_id,
            execution_date=self.execution_date)

    @staticmethod
    def _deserialize_value(raw, enable_pickling):
        """
        Decode a stored XCom payload.

        :param raw: the binary content of the ``value`` column
        :param enable_pickling: unpickle when truthy, otherwise parse as
            UTF-8 encoded JSON
        :return: the decoded Python object
        :raises ValueError: if the payload is not valid UTF-8 encoded JSON
            (logged before being re-raised)
        """
        if enable_pickling:
            # NOTE(security): pickle.loads executes arbitrary code embedded in
            # the payload; only safe if everything that can write to the xcom
            # table is trusted.
            return pickle.loads(raw)

        try:
            return json.loads(raw.decode('UTF-8'))
        except ValueError:
            log = LoggingMixin().log
            log.error("Could not deserialize the XCOM value from JSON. "
                      "If you are using pickles instead of JSON "
                      "for XCOM, then you need to enable pickle "
                      "support for XCOM in your airflow config.")
            raise

    @classmethod
    @provide_session
    def set(
            cls,
            key,
            value,
            execution_date,
            task_id,
            dag_id,
            enable_pickling=None,
            session=None):
        """
        Store an XCom value, replacing any existing XCom with the same
        key, execution_date, task_id and dag_id.

        TODO: "pickling" has been deprecated and JSON is preferred. "pickling" will be
        removed in Airflow 2.0.

        :param key: key of the XCom
        :param value: value to store
        :param execution_date: execution date the XCom belongs to
        :param task_id: id of the task storing the XCom
        :param dag_id: id of the dag storing the XCom
        :param enable_pickling: If pickling is not enabled, the XCOM value
            will be serialized as JSON instead. ``None`` reads the
            ``core.enable_xcom_pickling`` configuration option.
        :param session: database session
        :return: None
        """
        session.expunge_all()

        if enable_pickling is None:
            enable_pickling = configuration.getboolean('core', 'enable_xcom_pickling')

        if enable_pickling:
            value = pickle.dumps(value)
        else:
            try:
                value = json.dumps(value).encode('UTF-8')
            except (TypeError, ValueError):
                # json.dumps raises TypeError for objects it cannot serialize
                # and ValueError for e.g. circular references.
                log = LoggingMixin().log
                log.error("Could not serialize the XCOM value into JSON. "
                          "If you are using pickles instead of JSON "
                          "for XCOM, then you need to enable pickle "
                          "support for XCOM in your airflow config.")
                raise

        # remove any duplicate XComs
        session.query(cls).filter(
            cls.key == key,
            cls.execution_date == execution_date,
            cls.task_id == task_id,
            cls.dag_id == dag_id).delete()

        session.commit()

        # insert new XCom
        session.add(XCom(
            key=key,
            value=value,
            execution_date=execution_date,
            task_id=task_id,
            dag_id=dag_id))

        session.commit()

    @classmethod
    @provide_session
    def get_one(
            cls,
            execution_date,
            key=None,
            task_id=None,
            dag_id=None,
            include_prior_dates=False,
            enable_pickling=None,
            session=None):
        """
        Retrieve an XCom value, optionally meeting certain criteria.
        TODO: "pickling" has been deprecated and JSON is preferred. "pickling" will be removed in Airflow 2.0.

        :param execution_date: execution date to match (or upper bound when
            ``include_prior_dates`` is set)
        :param key: only consider XComs with this key, if given
        :param task_id: only consider XComs from this task, if given
        :param dag_id: only consider XComs from this dag, if given
        :param include_prior_dates: also match earlier execution dates
        :param enable_pickling: If pickling is not enabled, the XCOM value will be parsed as JSON instead.
        :param session: database session
        :return: the value of the most recent matching XCom, or None when
            nothing matches
        """
        filters = []
        if key:
            filters.append(cls.key == key)
        if task_id:
            filters.append(cls.task_id == task_id)
        if dag_id:
            filters.append(cls.dag_id == dag_id)
        if include_prior_dates:
            filters.append(cls.execution_date <= execution_date)
        else:
            filters.append(cls.execution_date == execution_date)

        query = (
            session.query(cls.value)
            .filter(and_(*filters))
            .order_by(cls.execution_date.desc(), cls.timestamp.desc()))

        result = query.first()
        if not result:
            return None

        if enable_pickling is None:
            enable_pickling = configuration.getboolean('core', 'enable_xcom_pickling')

        return cls._deserialize_value(result.value, enable_pickling)

    @classmethod
    @provide_session
    def get_many(
            cls,
            execution_date,
            key=None,
            task_ids=None,
            dag_ids=None,
            include_prior_dates=False,
            limit=100,
            enable_pickling=None,
            session=None):
        """
        Retrieve XCom rows, optionally meeting certain criteria, with their
        ``value`` attribute replaced by the deserialized object.
        TODO: "pickling" has been deprecated and JSON is preferred. "pickling" will be removed in Airflow 2.0.

        :param execution_date: execution date to match (or upper bound when
            ``include_prior_dates`` is set)
        :param key: only return XComs with this key, if given
        :param task_ids: only return XComs from these task(s), if given
        :param dag_ids: only return XComs from these dag(s), if given
        :param include_prior_dates: also match earlier execution dates
        :param limit: maximum number of rows returned
        :param enable_pickling: If pickling is not enabled, the XCOM values will be parsed as JSON instead.
        :param session: database session
        :return: list of XCom objects, most recent first
        """
        filters = []
        if key:
            filters.append(cls.key == key)
        if task_ids:
            filters.append(cls.task_id.in_(as_tuple(task_ids)))
        if dag_ids:
            filters.append(cls.dag_id.in_(as_tuple(dag_ids)))
        if include_prior_dates:
            filters.append(cls.execution_date <= execution_date)
        else:
            filters.append(cls.execution_date == execution_date)

        query = (
            session.query(cls)
            .filter(and_(*filters))
            .order_by(cls.execution_date.desc(), cls.timestamp.desc())
            .limit(limit))
        results = query.all()

        if enable_pickling is None:
            enable_pickling = configuration.getboolean('core', 'enable_xcom_pickling')

        for result in results:
            result.value = cls._deserialize_value(result.value, enable_pickling)
        return results

    @classmethod
    @provide_session
    def delete(cls, xcoms, session=None):
        """
        Delete one XCom or an iterable of XComs.

        :param xcoms: a single XCom or an iterable of XCom objects
        :param session: database session
        :raises TypeError: if any element is not an XCom
        """
        if isinstance(xcoms, XCom):
            xcoms = [xcoms]
        for xcom in xcoms:
            if not isinstance(xcom, XCom):
                raise TypeError(
                    'Expected XCom; received {}'.format(xcom.__class__.__name__)
                )
            session.delete(xcom)
        session.commit()
| |
| |
class DagStat(Base):
    """
    Row of the ``dag_stats`` table: the number of dag runs (``count``) a dag
    has in a given ``state``, plus a ``dirty`` flag marking the row as
    needing a recount.
    """
    __tablename__ = "dag_stats"

    dag_id = Column(String(ID_LEN), primary_key=True)
    state = Column(String(50), primary_key=True)
    count = Column(Integer, default=0)
    dirty = Column(Boolean, default=False)

    def __init__(self, dag_id, state, count=0, dirty=False):
        self.dag_id = dag_id
        self.state = state
        self.count = count
        self.dirty = dirty

    @staticmethod
    @provide_session
    def set_dirty(dag_id, session=None):
        """
        Flag every stat row of a dag as dirty, creating missing rows first.

        Failures are rolled back and logged, not raised.

        :param dag_id: the dag_id to mark dirty
        :param session: database session
        :return:
        """
        DagStat.create(dag_id=dag_id, session=session)

        try:
            locked_rows = (
                session.query(DagStat)
                .filter(DagStat.dag_id == dag_id)
                .with_for_update()
                .all()
            )
            for row in locked_rows:
                row.dirty = True
            session.commit()
        except Exception as e:
            session.rollback()
            log = LoggingMixin().log
            log.warning("Could not update dag stats for %s", dag_id)
            log.exception(e)

    @staticmethod
    @provide_session
    def update(dag_ids=None, dirty_only=True, session=None):
        """
        Updates the stats for dirty/out-of-sync dags

        Failures are rolled back and logged, not raised.

        :param dag_ids: dag_ids to be updated
        :type dag_ids: list
        :param dirty_only: only updated for marked dirty, defaults to True
        :type dirty_only: bool
        :param session: db session to use
        :type session: Session
        """
        try:
            stat_query = session.query(DagStat)
            if dag_ids:
                stat_query = stat_query.filter(DagStat.dag_id.in_(set(dag_ids)))
            if dirty_only:
                # ``== True`` builds a SQL expression; it is not a Python
                # truthiness test.
                stat_query = stat_query.filter(DagStat.dirty == True)  # noqa: E712

            locked_stats = stat_query.with_for_update().all()

            ids = {dag_stat.dag_id for dag_stat in locked_stats}

            # avoid querying with an empty IN clause
            if not ids:
                session.commit()
                return

            dagstat_states = set(itertools.product(ids, State.dag_states))
            run_counts = (
                session.query(DagRun.dag_id, DagRun.state, func.count('*'))
                .filter(DagRun.dag_id.in_(ids))
                .group_by(DagRun.dag_id, DagRun.state)
            )

            counts = {
                (dag_id, state): count
                for dag_id, state, count in run_counts
            }
            # Every (dag, state) pair gets a row; pairs without runs count 0.
            for dag_id, state in dagstat_states:
                session.merge(
                    DagStat(
                        dag_id=dag_id,
                        state=state,
                        count=counts.get((dag_id, state), 0),
                        dirty=False)
                )

            session.commit()
        except Exception as e:
            session.rollback()
            log = LoggingMixin().log
            log.warning("Could not update dag stat table")
            log.exception(e)

    @staticmethod
    @provide_session
    def create(dag_id, session=None):
        """
        Creates the missing states in the stats table for the dag specified

        Each missing row is merged and committed on its own; a failure is
        rolled back and logged, and the remaining states are still tried.

        :param dag_id: dag id of the dag to create stats for
        :param session: database session
        :return:
        """
        # unfortunately sqlalchemy does not know upsert
        existing_rows = session.query(DagStat).filter(DagStat.dag_id == dag_id).all()
        known_states = [row.state for row in existing_rows]
        for state in State.dag_states:
            if state in known_states:
                continue
            try:
                session.merge(DagStat(dag_id=dag_id, state=state))
                session.commit()
            except Exception as e:
                session.rollback()
                log = LoggingMixin().log
                log.warning("Could not create stat record")
                log.exception(e)
| |
| |
class DagRun(Base, LoggingMixin):
    """
    DagRun describes an instance of a Dag. It can be created
    by the scheduler (for regular runs) or by an external trigger
    """
    __tablename__ = "dag_run"

    ID_PREFIX = 'scheduled__'
    ID_FORMAT_PREFIX = ID_PREFIX + '{0}'

    id = Column(Integer, primary_key=True)
    dag_id = Column(String(ID_LEN))
    execution_date = Column(DateTime, default=func.now())
    start_date = Column(DateTime, default=func.now())
    end_date = Column(DateTime)
    _state = Column('state', String(50), default=State.RUNNING)
    run_id = Column(String(ID_LEN))
    external_trigger = Column(Boolean, default=True)
    conf = Column(PickleType)

    # Not persisted; must be assigned by the caller before any method that
    # goes through ``get_dag()`` is used.
    dag = None

    __table_args__ = (
        Index('dr_run_id', dag_id, run_id, unique=True),
    )

    def __repr__(self):
        return (
            '<DagRun {dag_id} @ {execution_date}: {run_id}, '
            'externally triggered: {external_trigger}>'
        ).format(
            dag_id=self.dag_id,
            execution_date=self.execution_date,
            run_id=self.run_id,
            external_trigger=self.external_trigger)

    def get_state(self):
        """Return the stored state of this dag run."""
        return self._state

    def set_state(self, state):
        """
        Set the state of this dag run and, when it actually changes and the
        run has a dag_id, mark the dag's stats as dirty.

        :param state: the new state
        """
        if self._state != state:
            self._state = state
            if self.dag_id is not None:
                # something really weird goes on here: if you try to close the session
                # dag runs will end up detached
                session = settings.Session()
                DagStat.set_dirty(self.dag_id, session=session)

    @declared_attr
    def state(cls):
        # Expose ``_state`` as ``state`` so that assignments go through
        # ``set_state``.
        return synonym('_state',
                       descriptor=property(cls.get_state, cls.set_state))

    @classmethod
    def id_for_date(cls, date, prefix=ID_FORMAT_PREFIX):
        """
        Build a run id from a date.

        :param date: object with an ``isoformat()`` method; only the first
            19 characters of the ISO string are used
        :param prefix: format string receiving the truncated ISO date
        :return: the formatted run id
        """
        return prefix.format(date.isoformat()[:19])

    @provide_session
    def refresh_from_db(self, session=None):
        """
        Reloads the current dagrun from the database

        Only ``id`` and ``state`` are copied from the stored row, which is
        looked up by dag_id, execution_date and run_id.

        :param session: database session
        """
        DR = DagRun

        exec_date = func.cast(self.execution_date, DateTime)

        dr = session.query(DR).filter(
            DR.dag_id == self.dag_id,
            func.cast(DR.execution_date, DateTime) == exec_date,
            DR.run_id == self.run_id
        ).one()

        self.id = dr.id
        self.state = dr.state

    @staticmethod
    @provide_session
    def find(dag_id=None, run_id=None, execution_date=None,
             state=None, external_trigger=None, no_backfills=False,
             session=None):
        """
        Returns a list of dag runs for the given search criteria, ordered
        by execution date.

        :param dag_id: the dag_id to find dag runs for
        :type dag_id: str
        :param run_id: defines the run id for this dag run
        :type run_id: string
        :param execution_date: the execution date, or a list of execution
            dates
        :type execution_date: datetime, list
        :param state: the state of the dag run
        :type state: State
        :param external_trigger: whether this dag run is externally triggered
        :type external_trigger: bool
        :param no_backfills: return no backfills (True), return all (False).
            Defaults to False
        :type no_backfills: bool
        :param session: database session
        :type session: Session
        """
        DR = DagRun

        qry = session.query(DR)
        if dag_id:
            qry = qry.filter(DR.dag_id == dag_id)
        if run_id:
            qry = qry.filter(DR.run_id == run_id)
        if execution_date:
            if isinstance(execution_date, list):
                qry = qry.filter(DR.execution_date.in_(execution_date))
            else:
                qry = qry.filter(DR.execution_date == execution_date)
        if state:
            qry = qry.filter(DR.state == state)
        if external_trigger is not None:
            qry = qry.filter(DR.external_trigger == external_trigger)
        if no_backfills:
            # in order to prevent a circular dependency
            from airflow.jobs import BackfillJob
            qry = qry.filter(DR.run_id.notlike(BackfillJob.ID_PREFIX + '%'))

        dr = qry.order_by(DR.execution_date).all()

        return dr

    @provide_session
    def get_task_instances(self, state=None, session=None):
        """
        Returns the task instances for this dag run

        :param state: a single state (string) or a collection of states to
            filter on; a collection may contain None to match task instances
            without a state
        :param session: database session
        """
        TI = TaskInstance
        tis = session.query(TI).filter(
            TI.dag_id == self.dag_id,
            TI.execution_date == self.execution_date,
        )
        if state:
            if isinstance(state, six.string_types):
                tis = tis.filter(TI.state == state)
            else:
                # this is required to deal with NULL values
                if None in state:
                    tis = tis.filter(
                        or_(TI.state.in_(state),
                            TI.state.is_(None))
                    )
                else:
                    tis = tis.filter(TI.state.in_(state))

        # For a partial dag only the tasks it contains are considered.
        if self.dag and self.dag.partial:
            tis = tis.filter(TI.task_id.in_(self.dag.task_ids))

        return tis.all()

    @provide_session
    def get_task_instance(self, task_id, session=None):
        """
        Returns the task instance specified by task_id for this dag run

        :param task_id: the task id
        :param session: database session
        :return: the matching TaskInstance, or None when there is none
        """

        TI = TaskInstance
        ti = session.query(TI).filter(
            TI.dag_id == self.dag_id,
            TI.execution_date == self.execution_date,
            TI.task_id == task_id
        ).first()

        return ti

    def get_dag(self):
        """
        Returns the Dag associated with this DagRun.

        :return: DAG
        :raises AirflowException: if ``self.dag`` has not been set
        """
        if not self.dag:
            raise AirflowException("The DAG (.dag) for {} needs to be set"
                                   .format(self))

        return self.dag

    @provide_session
    def get_previous_dagrun(self, session=None):
        """The previous DagRun, if there is one"""

        return session.query(DagRun).filter(
            DagRun.dag_id == self.dag_id,
            DagRun.execution_date < self.execution_date
        ).order_by(
            DagRun.execution_date.desc()
        ).first()

    @provide_session
    def get_previous_scheduled_dagrun(self, session=None):
        """The previous, SCHEDULED DagRun, if there is one"""
        dag = self.get_dag()

        return session.query(DagRun).filter(
            DagRun.dag_id == self.dag_id,
            DagRun.execution_date == dag.previous_schedule(self.execution_date)
        ).first()

    @provide_session
    def update_state(self, session=None):
        """
        Determines the overall state of the DagRun based on the state
        of its TaskInstances.

        The run is merged into the session and committed before returning.

        :param session: database session
        :return: State
        """

        dag = self.get_dag()

        tis = self.get_task_instances(session=session)

        self.log.info("Updating state for %s considering %s task(s)", self, len(tis))

        # Iterate over a copy because removed task instances are dropped
        # from ``tis`` while looping.
        for ti in list(tis):
            # skip in db?
            if ti.state == State.REMOVED:
                tis.remove(ti)
            else:
                ti.task = dag.get_task(ti.task_id)

        # pre-calculate
        # db is faster
        start_dttm = datetime.utcnow()
        unfinished_tasks = self.get_task_instances(
            state=State.unfinished(),
            session=session
        )
        none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
        none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks)
        # small speed up
        if unfinished_tasks and none_depends_on_past and none_task_concurrency:
            # todo: this can actually get pretty slow: one task costs between 0.01-0.15s
            no_dependencies_met = True
            for ut in unfinished_tasks:
                # We need to flag upstream and check for changes because upstream
                # failures can result in deadlock false positives
                old_state = ut.state
                deps_met = ut.are_dependencies_met(
                    dep_context=DepContext(
                        flag_upstream_failed=True,
                        ignore_in_retry_period=True),
                    session=session)
                if deps_met or old_state != ut.current_state(session=session):
                    no_dependencies_met = False
                    break

        duration = (datetime.utcnow() - start_dttm).total_seconds() * 1000
        Stats.timing("dagrun.dependency-check.{}.{}".
                     format(self.dag_id, self.execution_date), duration)

        # future: remove the check on adhoc tasks (=active_tasks)
        if len(tis) == len(dag.active_tasks):
            root_ids = [t.task_id for t in dag.roots]
            roots = [t for t in tis if t.task_id in root_ids]

            # if all roots finished and at least one failed, the run failed
            if (not unfinished_tasks and
                    any(r.state in (State.FAILED, State.UPSTREAM_FAILED) for r in roots)):
                self.log.info('Marking run %s failed', self)
                self.state = State.FAILED

            # if all roots succeeded and no unfinished tasks, the run succeeded
            elif not unfinished_tasks and all(r.state in (State.SUCCESS, State.SKIPPED)
                                              for r in roots):
                self.log.info('Marking run %s successful', self)
                self.state = State.SUCCESS

            # if *all tasks* are deadlocked, the run failed
            # (``no_dependencies_met`` is only bound when the three preceding
            # conditions hold; short-circuiting keeps this safe)
            elif (unfinished_tasks and none_depends_on_past and
                  none_task_concurrency and no_dependencies_met):
                self.log.info('Deadlock; marking run %s failed', self)
                self.state = State.FAILED

            # finally, if the roots aren't done, the dag is still running
            else:
                self.state = State.RUNNING

        # todo: determine we want to use with_for_update to make sure to lock the run
        session.merge(self)
        session.commit()

        return self.state

    @provide_session
    def verify_integrity(self, session=None):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        :param session: database session
        """
        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed tasks
        task_ids = []
        for ti in tis:
            task_ids.append(ti.task_id)
            try:
                dag.get_task(ti.task_id)
            except AirflowException:
                # NOTE(review): ``is not`` compares object identity, not
                # string equality, so a state loaded from the database may
                # compare as "not RUNNING" even when it equals
                # State.RUNNING. Left unchanged because update_state() calls
                # dag.get_task() on every non-REMOVED task instance and may
                # rely on this; confirm before switching to ``!=``.
                if self.state is not State.RUNNING and not dag.partial:
                    ti.state = State.REMOVED

        # check for missing tasks
        for task in dag.tasks:
            if task.adhoc:
                continue

            if task.task_id not in task_ids:
                ti = TaskInstance(task, self.execution_date)
                session.add(ti)

        session.commit()

    @staticmethod
    def get_run(session, dag_id, execution_date):
        """
        :param session: database session
        :param dag_id: DAG ID
        :type dag_id: unicode
        :param execution_date: execution date
        :type execution_date: datetime
        :return: DagRun corresponding to the given dag_id and execution date
            if one exists and was not externally triggered. None otherwise.
        :rtype: DagRun
        """
        qry = session.query(DagRun).filter(
            DagRun.dag_id == dag_id,
            DagRun.external_trigger == False,  # noqa: E712 (SQL expression)
            DagRun.execution_date == execution_date,
        )
        return qry.first()

    @property
    def is_backfill(self):
        """
        Whether the run id starts with the backfill job's id prefix.

        A run without a run_id is not considered a backfill.
        """
        # imported here in order to prevent a circular dependency
        from airflow.jobs import BackfillJob
        return (
            self.run_id is not None and
            self.run_id.startswith(BackfillJob.ID_PREFIX)
        )

    @classmethod
    @provide_session
    def get_latest_runs(cls, session=None):
        """Returns the latest DagRun for each DAG. """
        # Latest execution date per dag ...
        subquery = (
            session
            .query(
                cls.dag_id,
                func.max(cls.execution_date).label('execution_date'))
            .group_by(cls.dag_id)
            .subquery()
        )
        # ... joined back to fetch the full rows.
        dagruns = (
            session
            .query(cls)
            .join(subquery,
                  and_(cls.dag_id == subquery.c.dag_id,
                       cls.execution_date == subquery.c.execution_date))
            .all()
        )
        return dagruns
| |
| |
class Pool(Base):
    """
    Row of the ``slot_pool`` table: a named pool with a number of slots.
    """
    __tablename__ = "slot_pool"

    id = Column(Integer, primary_key=True)
    pool = Column(String(50), unique=True)
    slots = Column(Integer, default=0)
    description = Column(Text)

    def __repr__(self):
        return self.pool

    def to_json(self):
        """Return the pool's columns as a plain dict."""
        return {
            'id': self.id,
            'pool': self.pool,
            'slots': self.slots,
            'description': self.description,
        }

    @provide_session
    def used_slots(self, session=None):
        """
        Returns the number of slots used at the moment, i.e. the number of
        task instances of this pool in the RUNNING state

        :param session: database session
        """
        running = (
            session
            .query(TaskInstance)
            .filter(TaskInstance.pool == self.pool)
            .filter(TaskInstance.state == State.RUNNING)
            .count()
        )
        return running

    @provide_session
    def queued_slots(self, session=None):
        """
        Returns the number of slots queued at the moment, i.e. the number of
        task instances of this pool in the QUEUED state

        :param session: database session
        """
        return (
            session
            .query(TaskInstance)
            .filter(TaskInstance.pool == self.pool)
            .filter(TaskInstance.state == State.QUEUED)
            .count()
        )

    @provide_session
    def open_slots(self, session=None):
        """
        Returns the number of slots open at the moment: the pool's slots
        minus the used and the queued ones

        :param session: database session
        """
        used_slots = self.used_slots(session=session)
        queued_slots = self.queued_slots(session=session)
        return self.slots - used_slots - queued_slots
| |
| |
class SlaMiss(Base):
    """
    Model that stores a history of the SLAs that have been missed.
    It is used to keep track of SLA failures over time and to avoid double
    triggering alert emails.
    """

    __tablename__ = "sla_miss"

    # Composite primary key: one row per task, dag and execution date.
    task_id = Column(String(ID_LEN), primary_key=True)
    dag_id = Column(String(ID_LEN), primary_key=True)
    execution_date = Column(DateTime, primary_key=True)

    email_sent = Column(Boolean, default=False)
    timestamp = Column(DateTime)
    description = Column(Text)
    notification_sent = Column(Boolean, default=False)

    def __repr__(self):
        # Rendered as the string form of (dag_id, task_id, ISO execution date).
        identity = (
            self.dag_id,
            self.task_id,
            self.execution_date.isoformat(),
        )
        return str(identity)
| |
| |
class ImportError(Base):
    """
    Row of the ``import_error`` table: a filename with a stacktrace and a
    timestamp.

    NOTE: within this module the class name shadows the builtin
    ``ImportError`` exception.
    """

    __tablename__ = "import_error"

    id = Column(Integer, primary_key=True)
    timestamp = Column(DateTime)
    filename = Column(String(1024))
    stacktrace = Column(Text)