| # -*- coding: utf-8 -*- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import multiprocessing |
| import os |
| import pickle # type: ignore |
| import signal |
| import unittest |
| from datetime import timedelta |
| from time import sleep |
| from unittest import mock |
| |
| import sqlalchemy |
| from dateutil.relativedelta import relativedelta |
| from numpy.testing import assert_array_almost_equal |
| from pendulum import utcnow |
| |
| from airflow import DAG, configuration, exceptions, jobs, settings, utils |
| from airflow.bin import cli |
| from airflow.configuration import AirflowConfigException, conf, run_command |
| from airflow.exceptions import AirflowException |
| from airflow.executors import SequentialExecutor |
| from airflow.hooks.base_hook import BaseHook |
| from airflow.hooks.sqlite_hook import SqliteHook |
| from airflow.models import BaseOperator, Connection, DagBag, DagRun, Pool, TaskFail, TaskInstance, Variable |
| from airflow.operators.bash_operator import BashOperator |
| from airflow.operators.check_operator import CheckOperator, ValueCheckOperator |
| from airflow.operators.dummy_operator import DummyOperator |
| from airflow.operators.python_operator import PythonOperator |
| from airflow.settings import Session |
| from airflow.utils import timezone |
| from airflow.utils.dates import days_ago, infer_time_unit, round_time, scale_time_units |
| from airflow.utils.state import State |
| from airflow.utils.timezone import datetime |
| from tests.test_utils.config import conf_vars |
| |
| DEV_NULL = '/dev/null' |
| TEST_DAG_FOLDER = os.path.join( |
| os.path.dirname(os.path.realpath(__file__)), 'dags') |
| DEFAULT_DATE = datetime(2015, 1, 1) |
| DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat() |
| DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10] |
| TEST_DAG_ID = 'unit_tests' |
| EXAMPLE_DAG_DEFAULT_DATE = days_ago(2) |
| |
| |
| class OperatorSubclass(BaseOperator): |
| """ |
| An operator to test template substitution |
| """ |
| template_fields = ['some_templated_field'] |
| |
| def __init__(self, some_templated_field, *args, **kwargs): |
| super().__init__(*args, **kwargs) |
| self.some_templated_field = some_templated_field |
| |
| def execute(self, context): |
| pass |
| |
| |
| class TestCore(unittest.TestCase): |
| TEST_SCHEDULE_WITH_NO_PREVIOUS_RUNS_DAG_ID = TEST_DAG_ID + 'test_schedule_dag_no_previous_runs' |
| TEST_SCHEDULE_DAG_FAKE_SCHEDULED_PREVIOUS_DAG_ID = \ |
| TEST_DAG_ID + 'test_schedule_dag_fake_scheduled_previous' |
| TEST_SCHEDULE_DAG_NO_END_DATE_UP_TO_TODAY_ONLY_DAG_ID = \ |
| TEST_DAG_ID + 'test_schedule_dag_no_end_date_up_to_today_only' |
| TEST_SCHEDULE_ONCE_DAG_ID = TEST_DAG_ID + 'test_schedule_dag_once' |
| TEST_SCHEDULE_RELATIVEDELTA_DAG_ID = TEST_DAG_ID + 'test_schedule_dag_relativedelta' |
| TEST_SCHEDULE_START_END_DATES_DAG_ID = TEST_DAG_ID + 'test_schedule_dag_start_end_dates' |
| |
| default_scheduler_args = {"num_runs": 1} |
| |
| def setUp(self): |
| self.dagbag = DagBag( |
| dag_folder=DEV_NULL, include_examples=True) |
| self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} |
| self.dag = DAG(TEST_DAG_ID, default_args=self.args) |
| self.dag_bash = self.dagbag.dags['example_bash_operator'] |
| self.runme_0 = self.dag_bash.get_task('runme_0') |
| self.run_after_loop = self.dag_bash.get_task('run_after_loop') |
| self.run_this_last = self.dag_bash.get_task('run_this_last') |
| |
| def tearDown(self): |
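        # KUBERNETES_VERSION is presumably set only on the Kubernetes CI
        # environment, where the shared database should not be cleaned up.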
| if os.environ.get('KUBERNETES_VERSION') is not None: |
| return |
| |
| dag_ids_to_clean = [ |
| TEST_DAG_ID, |
| self.TEST_SCHEDULE_WITH_NO_PREVIOUS_RUNS_DAG_ID, |
| self.TEST_SCHEDULE_DAG_FAKE_SCHEDULED_PREVIOUS_DAG_ID, |
| self.TEST_SCHEDULE_DAG_NO_END_DATE_UP_TO_TODAY_ONLY_DAG_ID, |
| self.TEST_SCHEDULE_ONCE_DAG_ID, |
| self.TEST_SCHEDULE_RELATIVEDELTA_DAG_ID, |
| self.TEST_SCHEDULE_START_END_DATES_DAG_ID, |
| ] |
| session = Session() |
| session.query(DagRun).filter( |
| DagRun.dag_id.in_(dag_ids_to_clean)).delete( |
| synchronize_session=False) |
| session.query(TaskInstance).filter( |
| TaskInstance.dag_id.in_(dag_ids_to_clean)).delete( |
| synchronize_session=False) |
| session.query(TaskFail).filter( |
| TaskFail.dag_id.in_(dag_ids_to_clean)).delete( |
| synchronize_session=False) |
| session.commit() |
| session.close() |
| |
| def test_schedule_dag_no_previous_runs(self): |
| """ |
| Tests scheduling a dag with no previous runs |
| """ |
| dag = DAG(self.TEST_SCHEDULE_WITH_NO_PREVIOUS_RUNS_DAG_ID) |
| dag.add_task(BaseOperator( |
| task_id="faketastic", |
| owner='Also fake', |
| start_date=datetime(2015, 1, 2, 0, 0))) |
| |
| dag_run = jobs.SchedulerJob(**self.default_scheduler_args).create_dag_run(dag) |
| self.assertIsNotNone(dag_run) |
| self.assertEqual(dag.dag_id, dag_run.dag_id) |
| self.assertIsNotNone(dag_run.run_id) |
| self.assertNotEqual('', dag_run.run_id) |
| self.assertEqual( |
| datetime(2015, 1, 2, 0, 0), |
| dag_run.execution_date, |
| msg='dag_run.execution_date did not match expectation: {0}' |
| .format(dag_run.execution_date) |
| ) |
| self.assertEqual(State.RUNNING, dag_run.state) |
| self.assertFalse(dag_run.external_trigger) |
| dag.clear() |
| |
| def test_schedule_dag_relativedelta(self): |
| """ |
| Tests scheduling a dag with a relativedelta schedule_interval |
| """ |
| delta = relativedelta(hours=+1) |
| dag = DAG(self.TEST_SCHEDULE_RELATIVEDELTA_DAG_ID, |
| schedule_interval=delta) |
| dag.add_task(BaseOperator( |
| task_id="faketastic", |
| owner='Also fake', |
| start_date=datetime(2015, 1, 2, 0, 0))) |
| |
| dag_run = jobs.SchedulerJob(**self.default_scheduler_args).create_dag_run(dag) |
| self.assertIsNotNone(dag_run) |
| self.assertEqual(dag.dag_id, dag_run.dag_id) |
| self.assertIsNotNone(dag_run.run_id) |
| self.assertNotEqual('', dag_run.run_id) |
| self.assertEqual( |
| datetime(2015, 1, 2, 0, 0), |
| dag_run.execution_date, |
| msg='dag_run.execution_date did not match expectation: {0}' |
| .format(dag_run.execution_date) |
| ) |
| self.assertEqual(State.RUNNING, dag_run.state) |
| self.assertFalse(dag_run.external_trigger) |
| dag_run2 = jobs.SchedulerJob(**self.default_scheduler_args).create_dag_run(dag) |
| self.assertIsNotNone(dag_run2) |
| self.assertEqual(dag.dag_id, dag_run2.dag_id) |
| self.assertIsNotNone(dag_run2.run_id) |
| self.assertNotEqual('', dag_run2.run_id) |
| self.assertEqual( |
| datetime(2015, 1, 2, 0, 0) + delta, |
| dag_run2.execution_date, |
| msg='dag_run2.execution_date did not match expectation: {0}' |
| .format(dag_run2.execution_date) |
| ) |
| self.assertEqual(State.RUNNING, dag_run2.state) |
| self.assertFalse(dag_run2.external_trigger) |
| dag.clear() |
| |
| def test_schedule_dag_fake_scheduled_previous(self): |
| """ |
| Test scheduling a dag where there is a prior DagRun |
| which has the same run_id as the next run should have |
| """ |
| delta = timedelta(hours=1) |
| |
| dag = DAG(self.TEST_SCHEDULE_DAG_FAKE_SCHEDULED_PREVIOUS_DAG_ID, |
| schedule_interval=delta, |
| start_date=DEFAULT_DATE) |
| dag.add_task(BaseOperator( |
| task_id="faketastic", |
| owner='Also fake', |
| start_date=DEFAULT_DATE)) |
| |
| scheduler = jobs.SchedulerJob(**self.default_scheduler_args) |
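        # Manually create an externally-triggered run that reuses the run_id
        # the scheduler would generate for DEFAULT_DATE ("fake" previous run)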
| dag.create_dagrun(run_id=DagRun.id_for_date(DEFAULT_DATE), |
| execution_date=DEFAULT_DATE, |
| state=State.SUCCESS, |
| external_trigger=True) |
| dag_run = scheduler.create_dag_run(dag) |
| self.assertIsNotNone(dag_run) |
| self.assertEqual(dag.dag_id, dag_run.dag_id) |
| self.assertIsNotNone(dag_run.run_id) |
| self.assertNotEqual('', dag_run.run_id) |
| self.assertEqual( |
| DEFAULT_DATE + delta, |
| dag_run.execution_date, |
| msg='dag_run.execution_date did not match expectation: {0}' |
| .format(dag_run.execution_date) |
| ) |
| self.assertEqual(State.RUNNING, dag_run.state) |
| self.assertFalse(dag_run.external_trigger) |
| |
| def test_schedule_dag_once(self): |
| """ |
        Tests scheduling a dag with an '@once' schedule: a run should be
        created the first time the scheduler is invoked, and not the second.
| """ |
| dag = DAG(self.TEST_SCHEDULE_ONCE_DAG_ID) |
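        # The '@once' preset should yield exactly one run, no matter how many
        # times create_dag_run is called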
| dag.schedule_interval = '@once' |
| dag.add_task(BaseOperator( |
| task_id="faketastic", |
| owner='Also fake', |
| start_date=datetime(2015, 1, 2, 0, 0))) |
| dag_run = jobs.SchedulerJob(**self.default_scheduler_args).create_dag_run(dag) |
| dag_run2 = jobs.SchedulerJob(**self.default_scheduler_args).create_dag_run(dag) |
| |
| self.assertIsNotNone(dag_run) |
| self.assertIsNone(dag_run2) |
| dag.clear() |
| |
| def test_fractional_seconds(self): |
| """ |
        Tests that fractional seconds are preserved when stored in the database
| """ |
| dag = DAG(TEST_DAG_ID + 'test_fractional_seconds') |
| dag.schedule_interval = '@once' |
| dag.add_task(BaseOperator( |
| task_id="faketastic", |
| owner='Also fake', |
| start_date=datetime(2015, 1, 2, 0, 0))) |
| |
| start_date = timezone.utcnow() |
| |
| run = dag.create_dagrun( |
| run_id='test_' + start_date.isoformat(), |
| execution_date=start_date, |
| start_date=start_date, |
| state=State.RUNNING, |
| external_trigger=False |
| ) |
| |
| run.refresh_from_db() |
| |
| self.assertEqual(start_date, run.execution_date, |
| "dag run execution_date loses precision") |
| self.assertEqual(start_date, run.start_date, |
| "dag run start_date loses precision ") |
| |
| def test_schedule_dag_start_end_dates(self): |
| """ |
| Tests that an attempt to schedule a task after the Dag's end_date |
| does not succeed. |
| """ |
| delta = timedelta(hours=1) |
| runs = 3 |
| start_date = DEFAULT_DATE |
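        # Choose end_date so that exactly `runs` execution dates fit between
        # start_date and end_date at the given interval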
| end_date = start_date + (runs - 1) * delta |
| |
| dag = DAG(self.TEST_SCHEDULE_START_END_DATES_DAG_ID, |
| start_date=start_date, |
| end_date=end_date, |
| schedule_interval=delta) |
| dag.add_task(BaseOperator(task_id='faketastic', owner='Also fake')) |
| |
| # Create and schedule the dag runs |
| dag_runs = [] |
| scheduler = jobs.SchedulerJob(**self.default_scheduler_args) |
| for _ in range(runs): |
| dag_runs.append(scheduler.create_dag_run(dag)) |
| |
| additional_dag_run = scheduler.create_dag_run(dag) |
| |
| for dag_run in dag_runs: |
| self.assertIsNotNone(dag_run) |
| |
| self.assertIsNone(additional_dag_run) |
| |
| def test_schedule_dag_no_end_date_up_to_today_only(self): |
| """ |
        Tests that a Dag created without an end_date can only be scheduled up
        to, but not including, the current datetime.

        For example, if today is 2016-01-01 and we are scheduling from a
        start_date of 2015-01-01, only runs up to, but not including,
        2016-01-01 should be created.
| """ |
| session = settings.Session() |
| delta = timedelta(days=1) |
| now = utcnow() |
| start_date = now.subtract(weeks=1) |
| |
| runs = (now - start_date).days |
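        # i.e. one scheduler-created run is expected per elapsed day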
| |
| dag = DAG(self.TEST_SCHEDULE_DAG_NO_END_DATE_UP_TO_TODAY_ONLY_DAG_ID, |
| start_date=start_date, |
| schedule_interval=delta) |
| dag.add_task(BaseOperator(task_id='faketastic', owner='Also fake')) |
| |
| dag_runs = [] |
| scheduler = jobs.SchedulerJob(**self.default_scheduler_args) |
| for _ in range(runs): |
| dag_run = scheduler.create_dag_run(dag) |
| dag_runs.append(dag_run) |
| |
| # Mark the DagRun as complete |
| dag_run.state = State.SUCCESS |
| session.merge(dag_run) |
| session.commit() |
| |
        # Attempt to schedule one more dag run (with today's execution date);
        # create_dag_run should refuse and return None
| additional_dag_run = scheduler.create_dag_run(dag) |
| |
| for dag_run in dag_runs: |
| self.assertIsNotNone(dag_run) |
| |
| self.assertIsNone(additional_dag_run) |
| |
    def test_confirm_unittest_mode(self):
        self.assertTrue(conf.getboolean('core', 'unit_test_mode'))
| |
| def test_pickling(self): |
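        # DAG.pickle() persists a DagPickle row whose `pickle` attribute is
        # the DAG itself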
| dp = self.dag.pickle() |
| self.assertEqual(dp.pickle.dag_id, self.dag.dag_id) |
| |
| def test_rich_comparison_ops(self): |
| |
| class DAGsubclass(DAG): |
| pass |
| |
| dag_eq = DAG(TEST_DAG_ID, default_args=self.args) |
| |
| dag_diff_load_time = DAG(TEST_DAG_ID, default_args=self.args) |
| dag_diff_name = DAG(TEST_DAG_ID + '_neq', default_args=self.args) |
| |
| dag_subclass = DAGsubclass(TEST_DAG_ID, default_args=self.args) |
| dag_subclass_diff_name = DAGsubclass( |
| TEST_DAG_ID + '2', default_args=self.args) |
| |
| for d in [dag_eq, dag_diff_name, dag_subclass, dag_subclass_diff_name]: |
| d.last_loaded = self.dag.last_loaded |
| |
| # test identity equality |
| self.assertEqual(self.dag, self.dag) |
| |
| # test dag (in)equality based on _comps |
| self.assertEqual(dag_eq, self.dag) |
| self.assertNotEqual(dag_diff_name, self.dag) |
| self.assertNotEqual(dag_diff_load_time, self.dag) |
| |
| # test dag inequality based on type even if _comps happen to match |
| self.assertNotEqual(dag_subclass, self.dag) |
| |
| # a dag should equal an unpickled version of itself |
| d = pickle.dumps(self.dag) |
| self.assertEqual(pickle.loads(d), self.dag) |
| |
| # dags are ordered based on dag_id no matter what the type is |
| self.assertLess(self.dag, dag_diff_name) |
| self.assertGreater(self.dag, dag_diff_load_time) |
| self.assertLess(self.dag, dag_subclass_diff_name) |
| |
| # greater than should have been created automatically by functools |
| self.assertGreater(dag_diff_name, self.dag) |
| |
| # hashes are non-random and match equality |
| self.assertEqual(hash(self.dag), hash(self.dag)) |
| self.assertEqual(hash(dag_eq), hash(self.dag)) |
| self.assertNotEqual(hash(dag_diff_name), hash(self.dag)) |
| self.assertNotEqual(hash(dag_subclass), hash(self.dag)) |
| |
    def test_check_operators(self):

        conn_id = "sqlite_default"

        captain_hook = BaseHook.get_hook(conn_id=conn_id)
        captain_hook.run("CREATE TABLE operator_test_table (a, b)")
        captain_hook.run("INSERT INTO operator_test_table VALUES (1, 2)")
| |
| t = CheckOperator( |
| task_id='check', |
| sql="select count(*) from operator_test_table", |
| conn_id=conn_id, |
| dag=self.dag) |
| t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| |
| t = ValueCheckOperator( |
| task_id='value_check', |
| pass_value=95, |
| tolerance=0.1, |
| conn_id=conn_id, |
| sql="SELECT 100", |
| dag=self.dag) |
| t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| |
| captainHook.run("drop table operator_test_table") |
| |
| def test_clear_api(self): |
| task = self.dag_bash.tasks[0] |
| task.clear( |
| start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, |
| upstream=True, downstream=True) |
| ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) |
| ti.are_dependents_done() |
| |
| def test_illegal_args(self): |
| """ |
| Tests that Operators reject illegal arguments |
| """ |
| msg = 'Invalid arguments were passed to BashOperator ' |
| '(task_id: test_illegal_args).' |
| with conf_vars({('operators', 'allow_illegal_arguments'): 'True'}): |
| with self.assertWarns(PendingDeprecationWarning) as warning: |
| BashOperator( |
| task_id='test_illegal_args', |
| bash_command='echo success', |
| dag=self.dag, |
| illegal_argument_1234='hello?') |
| assert any(msg in str(w) for w in warning.warnings) |
| |
| def test_illegal_args_forbidden(self): |
| """ |
        Tests that operators raise an AirflowException on illegal arguments
        when allow_illegal_arguments is disabled (the default).
| """ |
| with self.assertRaises(AirflowException) as ctx: |
| BashOperator( |
| task_id='test_illegal_args', |
| bash_command='echo success', |
| dag=self.dag, |
| illegal_argument_1234='hello?') |
| self.assertIn( |
| ('Invalid arguments were passed to BashOperator ' |
| '(task_id: test_illegal_args).'), |
| str(ctx.exception)) |
| |
| def test_bash_operator(self): |
| t = BashOperator( |
| task_id='test_bash_operator', |
| bash_command="echo success", |
| dag=self.dag) |
| t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| |
| def test_bash_operator_multi_byte_output(self): |
| t = BashOperator( |
| task_id='test_multi_byte_bash_operator', |
| bash_command="echo \u2600", |
| dag=self.dag, |
| output_encoding='utf-8') |
| t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| |
| def test_bash_operator_kill(self): |
| import psutil |
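        # Embed this test's PID in the sleep duration so the spawned `sleep`
        # process can be uniquely identified by its command line below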
| sleep_time = "100%d" % os.getpid() |
| t = BashOperator( |
| task_id='test_bash_operator_kill', |
| execution_timeout=timedelta(seconds=1), |
| bash_command="/bin/bash -c 'sleep %s'" % sleep_time, |
| dag=self.dag) |
| self.assertRaises( |
| exceptions.AirflowTaskTimeout, |
| t.run, |
| start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) |
| sleep(2) |
| pid = -1 |
| for proc in psutil.process_iter(): |
| if proc.cmdline() == ['sleep', sleep_time]: |
| pid = proc.pid |
| if pid != -1: |
| os.kill(pid, signal.SIGTERM) |
| self.fail("BashOperator's subprocess still running after stopping on timeout!") |
| |
| def test_on_failure_callback(self): |
        # Mutable container so the closure below can record that it was
        # called (``nonlocal`` would work equally well in Python 3)
| data = {'called': False} |
| |
| def check_failure(context, test_case=self): |
| data['called'] = True |
| error = context.get('exception') |
| test_case.assertIsInstance(error, AirflowException) |
| |
| t = BashOperator( |
| task_id='check_on_failure_callback', |
| bash_command="exit 1", |
| dag=self.dag, |
| on_failure_callback=check_failure) |
| self.assertRaises( |
| exceptions.AirflowException, |
| t.run, |
| start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| self.assertTrue(data['called']) |
| |
| def test_dryrun(self): |
| t = BashOperator( |
| task_id='test_dryrun', |
| bash_command="echo success", |
| dag=self.dag) |
| t.dry_run() |
| |
| def test_sqlite(self): |
| import airflow.operators.sqlite_operator |
| t = airflow.operators.sqlite_operator.SqliteOperator( |
| task_id='time_sqlite', |
| sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))", |
| dag=self.dag) |
| t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| |
| def test_timeout(self): |
| t = PythonOperator( |
| task_id='test_timeout', |
| execution_timeout=timedelta(seconds=1), |
| python_callable=lambda: sleep(5), |
| dag=self.dag) |
| self.assertRaises( |
| exceptions.AirflowTaskTimeout, |
| t.run, |
| start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| |
| def test_python_op(self): |
| def test_py_op(templates_dict, ds, **kwargs): |
            if templates_dict['ds'] != ds:
                raise Exception("failure")
| |
| t = PythonOperator( |
| task_id='test_py_op', |
| python_callable=test_py_op, |
| templates_dict={'ds': "{{ ds }}"}, |
| dag=self.dag) |
| t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| |
| def test_complex_template(self): |
| def verify_templated_field(context): |
| self.assertEqual(context['ti'].task.some_templated_field['bar'][1], |
| context['ds']) |
| |
| t = OperatorSubclass( |
| task_id='test_complex_template', |
| some_templated_field={ |
| 'foo': '123', |
| 'bar': ['baz', '{{ ds }}'] |
| }, |
| dag=self.dag) |
| t.execute = verify_templated_field |
| t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| |
| def test_template_with_variable(self): |
| """ |
| Test the availability of variables in templates |
| """ |
| val = { |
| 'test_value': 'a test value' |
| } |
| Variable.set("a_variable", val['test_value']) |
| |
| def verify_templated_field(context): |
| self.assertEqual(context['ti'].task.some_templated_field, |
| val['test_value']) |
| |
| t = OperatorSubclass( |
| task_id='test_complex_template', |
| some_templated_field='{{ var.value.a_variable }}', |
| dag=self.dag) |
| t.execute = verify_templated_field |
| t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| |
| def test_template_with_json_variable(self): |
| """ |
| Test the availability of variables (serialized as JSON) in templates |
| """ |
| val = { |
| 'test_value': {'foo': 'bar', 'obj': {'v1': 'yes', 'v2': 'no'}} |
| } |
| Variable.set("a_variable", val['test_value'], serialize_json=True) |
| |
| def verify_templated_field(context): |
| self.assertEqual(context['ti'].task.some_templated_field, |
| val['test_value']['obj']['v2']) |
| |
| t = OperatorSubclass( |
| task_id='test_complex_template', |
| some_templated_field='{{ var.json.a_variable.obj.v2 }}', |
| dag=self.dag) |
| t.execute = verify_templated_field |
| t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| |
| def test_template_with_json_variable_as_value(self): |
| """ |
| Test the availability of variables (serialized as JSON) in templates, but |
| accessed as a value |
| """ |
| val = { |
| 'test_value': {'foo': 'bar'} |
| } |
| Variable.set("a_variable", val['test_value'], serialize_json=True) |
| |
| def verify_templated_field(context): |
| self.assertEqual(context['ti'].task.some_templated_field, |
| '{\n "foo": "bar"\n}') |
| |
| t = OperatorSubclass( |
| task_id='test_complex_template', |
| some_templated_field='{{ var.value.a_variable }}', |
| dag=self.dag) |
| t.execute = verify_templated_field |
| t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| |
| def test_template_non_bool(self): |
| """ |
        Test templates can handle objects with no sense of truthiness
        (truth-testing them raises a TypeError)
| """ |
| |
| class NonBoolObject: |
| def __len__(self): |
| return NotImplemented |
| |
| def __bool__(self): |
| return NotImplemented |
| |
| t = OperatorSubclass( |
| task_id='test_bad_template_obj', |
| some_templated_field=NonBoolObject(), |
| dag=self.dag) |
| t.resolve_template_files() |
| |
| def test_task_get_template(self): |
| TI = TaskInstance |
| ti = TI( |
| task=self.runme_0, execution_date=DEFAULT_DATE) |
| ti.dag = self.dag_bash |
| ti.run(ignore_ti_state=True) |
| context = ti.get_template_context() |
| |
        # DEFAULT_DATE is 2015-01-01
| self.assertEqual(context['ds'], '2015-01-01') |
| self.assertEqual(context['ds_nodash'], '20150101') |
| |
| # next_ds is 2015-01-02 as the dag interval is daily |
| self.assertEqual(context['next_ds'], '2015-01-02') |
| self.assertEqual(context['next_ds_nodash'], '20150102') |
| |
| # prev_ds is 2014-12-31 as the dag interval is daily |
| self.assertEqual(context['prev_ds'], '2014-12-31') |
| self.assertEqual(context['prev_ds_nodash'], '20141231') |
| |
| self.assertEqual(context['ts'], '2015-01-01T00:00:00+00:00') |
| self.assertEqual(context['ts_nodash'], '20150101T000000') |
| self.assertEqual(context['ts_nodash_with_tz'], '20150101T000000+0000') |
| |
| self.assertEqual(context['yesterday_ds'], '2014-12-31') |
| self.assertEqual(context['yesterday_ds_nodash'], '20141231') |
| |
| self.assertEqual(context['tomorrow_ds'], '2015-01-02') |
| self.assertEqual(context['tomorrow_ds_nodash'], '20150102') |
| |
| def test_local_task_job(self): |
| TI = TaskInstance |
| ti = TI( |
| task=self.runme_0, execution_date=DEFAULT_DATE) |
| job = jobs.LocalTaskJob(task_instance=ti, ignore_ti_state=True) |
| job.run() |
| |
| def test_raw_job(self): |
| TI = TaskInstance |
| ti = TI( |
| task=self.runme_0, execution_date=DEFAULT_DATE) |
| ti.dag = self.dag_bash |
| ti.run(ignore_ti_state=True) |
| |
| def test_variable_set_get_round_trip(self): |
| Variable.set("tested_var_set_id", "Monday morning breakfast") |
| self.assertEqual("Monday morning breakfast", Variable.get("tested_var_set_id")) |
| |
| def test_variable_set_get_round_trip_json(self): |
| value = {"a": 17, "b": 47} |
| Variable.set("tested_var_set_id", value, serialize_json=True) |
| self.assertEqual(value, Variable.get("tested_var_set_id", deserialize_json=True)) |
| |
| def test_get_non_existing_var_should_return_default(self): |
| default_value = "some default val" |
| self.assertEqual(default_value, Variable.get("thisIdDoesNotExist", |
| default_var=default_value)) |
| |
| def test_get_non_existing_var_should_raise_key_error(self): |
| with self.assertRaises(KeyError): |
| Variable.get("thisIdDoesNotExist") |
| |
| def test_get_non_existing_var_with_none_default_should_return_none(self): |
| self.assertIsNone(Variable.get("thisIdDoesNotExist", default_var=None)) |
| |
| def test_get_non_existing_var_should_not_deserialize_json_default(self): |
| default_value = "}{ this is a non JSON default }{" |
| self.assertEqual(default_value, Variable.get("thisIdDoesNotExist", |
| default_var=default_value, |
| deserialize_json=True)) |
| |
| def test_variable_setdefault_round_trip(self): |
| key = "tested_var_setdefault_1_id" |
| value = "Monday morning breakfast in Paris" |
| Variable.setdefault(key, value) |
| self.assertEqual(value, Variable.get(key)) |
| |
| def test_variable_setdefault_round_trip_json(self): |
| key = "tested_var_setdefault_2_id" |
| value = {"city": 'Paris', "Happiness": True} |
| Variable.setdefault(key, value, deserialize_json=True) |
| self.assertEqual(value, Variable.get(key, deserialize_json=True)) |
| |
| def test_variable_setdefault_existing_json(self): |
| key = "tested_var_setdefault_2_id" |
| value = {"city": 'Paris', "Happiness": True} |
| Variable.set(key, value, serialize_json=True) |
| val = Variable.setdefault(key, value, deserialize_json=True) |
        # Check that both the returned value and the stored value are handled
        # correctly.
| self.assertEqual(value, val) |
| self.assertEqual(value, Variable.get(key, deserialize_json=True)) |
| |
| def test_variable_delete(self): |
| key = "tested_var_delete" |
| value = "to be deleted" |
| |
| # No-op if the variable doesn't exist |
| Variable.delete(key) |
| with self.assertRaises(KeyError): |
| Variable.get(key) |
| |
| # Set the variable |
| Variable.set(key, value) |
| self.assertEqual(value, Variable.get(key)) |
| |
| # Delete the variable |
| Variable.delete(key) |
| with self.assertRaises(KeyError): |
| Variable.get(key) |
| |
| def test_parameterized_config_gen(self): |
| |
| cfg = configuration.parameterized_config(configuration.DEFAULT_CONFIG) |
| |
| # making sure some basic building blocks are present: |
| self.assertIn("[core]", cfg) |
| self.assertIn("dags_folder", cfg) |
| self.assertIn("sql_alchemy_conn", cfg) |
| self.assertIn("fernet_key", cfg) |
| |
| # making sure replacement actually happened |
| self.assertNotIn("{AIRFLOW_HOME}", cfg) |
| self.assertNotIn("{FERNET_KEY}", cfg) |
| |
| def test_config_use_original_when_original_and_fallback_are_present(self): |
| self.assertTrue(conf.has_option("core", "FERNET_KEY")) |
| self.assertFalse(conf.has_option("core", "FERNET_KEY_CMD")) |
| |
| FERNET_KEY = conf.get('core', 'FERNET_KEY') |
| |
| with conf_vars({('core', 'FERNET_KEY_CMD'): 'printf HELLO'}): |
| FALLBACK_FERNET_KEY = conf.get( |
| "core", |
| "FERNET_KEY" |
| ) |
| |
| self.assertEqual(FERNET_KEY, FALLBACK_FERNET_KEY) |
| |
| def test_config_throw_error_when_original_and_fallback_is_absent(self): |
| self.assertTrue(conf.has_option("core", "FERNET_KEY")) |
| self.assertFalse(conf.has_option("core", "FERNET_KEY_CMD")) |
| |
| with conf_vars({('core', 'fernet_key'): None}): |
| with self.assertRaises(AirflowConfigException) as cm: |
| conf.get("core", "FERNET_KEY") |
| |
| exception = str(cm.exception) |
| message = "section/key [core/fernet_key] not found in config" |
| self.assertEqual(message, exception) |
| |
| def test_config_override_original_when_non_empty_envvar_is_provided(self): |
| key = "AIRFLOW__CORE__FERNET_KEY" |
| value = "some value" |
| |
| with mock.patch.dict('os.environ', {key: value}): |
| FERNET_KEY = conf.get('core', 'FERNET_KEY') |
| |
| self.assertEqual(value, FERNET_KEY) |
| |
| def test_config_override_original_when_empty_envvar_is_provided(self): |
| key = "AIRFLOW__CORE__FERNET_KEY" |
| value = "" |
| |
| with mock.patch.dict('os.environ', {key: value}): |
| FERNET_KEY = conf.get('core', 'FERNET_KEY') |
| |
| self.assertEqual(value, FERNET_KEY) |
| |
| def test_round_time(self): |
| |
| rt1 = round_time(datetime(2015, 1, 1, 6), timedelta(days=1)) |
| self.assertEqual(datetime(2015, 1, 1, 0, 0), rt1) |
| |
| rt2 = round_time(datetime(2015, 1, 2), relativedelta(months=1)) |
| self.assertEqual(datetime(2015, 1, 1, 0, 0), rt2) |
| |
| rt3 = round_time(datetime(2015, 9, 16, 0, 0), timedelta(1), datetime( |
| 2015, 9, 14, 0, 0)) |
| self.assertEqual(datetime(2015, 9, 16, 0, 0), rt3) |
| |
| rt4 = round_time(datetime(2015, 9, 15, 0, 0), timedelta(1), datetime( |
| 2015, 9, 14, 0, 0)) |
| self.assertEqual(datetime(2015, 9, 15, 0, 0), rt4) |
| |
| rt5 = round_time(datetime(2015, 9, 14, 0, 0), timedelta(1), datetime( |
| 2015, 9, 14, 0, 0)) |
| self.assertEqual(datetime(2015, 9, 14, 0, 0), rt5) |
| |
| rt6 = round_time(datetime(2015, 9, 13, 0, 0), timedelta(1), datetime( |
| 2015, 9, 14, 0, 0)) |
| self.assertEqual(datetime(2015, 9, 14, 0, 0), rt6) |
| |
| def test_infer_time_unit(self): |
| |
| self.assertEqual('minutes', infer_time_unit([130, 5400, 10])) |
| |
| self.assertEqual('seconds', infer_time_unit([110, 50, 10, 100])) |
| |
| self.assertEqual('hours', infer_time_unit([100000, 50000, 10000, 20000])) |
| |
| self.assertEqual('days', infer_time_unit([200000, 100000])) |
| |
| def test_scale_time_units(self): |
| |
        # use assert_array_almost_equal from numpy.testing since we are
        # comparing floating point arrays
| arr1 = scale_time_units([130, 5400, 10], 'minutes') |
| assert_array_almost_equal(arr1, [2.167, 90.0, 0.167], decimal=3) |
| |
| arr2 = scale_time_units([110, 50, 10, 100], 'seconds') |
| assert_array_almost_equal(arr2, [110.0, 50.0, 10.0, 100.0], decimal=3) |
| |
| arr3 = scale_time_units([100000, 50000, 10000, 20000], 'hours') |
| assert_array_almost_equal(arr3, [27.778, 13.889, 2.778, 5.556], |
| decimal=3) |
| |
| arr4 = scale_time_units([200000, 100000], 'days') |
| assert_array_almost_equal(arr4, [2.315, 1.157], decimal=3) |
| |
| def test_bad_trigger_rule(self): |
| with self.assertRaises(AirflowException): |
| DummyOperator( |
| task_id='test_bad_trigger', |
| trigger_rule="non_existent", |
| dag=self.dag) |
| |
| def test_terminate_task(self): |
| """If a task instance's db state get deleted, it should fail""" |
| TI = TaskInstance |
| dag = self.dagbag.dags.get('test_utils') |
| task = dag.task_dict.get('sleeps_forever') |
| |
| ti = TI(task=task, execution_date=DEFAULT_DATE) |
| job = jobs.LocalTaskJob( |
| task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) |
| |
| # Running task instance asynchronously |
| p = multiprocessing.Process(target=job.run) |
| p.start() |
| sleep(5) |
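        # Dispose pooled connections after the fork so this process opens a
        # fresh DB connection that is not shared with the child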
| settings.engine.dispose() |
| session = settings.Session() |
| ti.refresh_from_db(session=session) |
| # making sure it's actually running |
| self.assertEqual(State.RUNNING, ti.state) |
| ti = session.query(TI).filter_by( |
| dag_id=task.dag_id, |
| task_id=task.task_id, |
| execution_date=DEFAULT_DATE |
| ).one() |
| |
| # deleting the instance should result in a failure |
| session.delete(ti) |
| session.commit() |
| # waiting for the async task to finish |
| p.join() |
| |
| # making sure that the task ended up as failed |
| ti.refresh_from_db(session=session) |
| self.assertEqual(State.FAILED, ti.state) |
| session.close() |
| |
| def test_task_fail_duration(self): |
| """If a task fails, the duration should be recorded in TaskFail""" |
| |
| p = BashOperator( |
| task_id='pass_sleepy', |
| bash_command='sleep 3', |
| dag=self.dag) |
| f = BashOperator( |
| task_id='fail_sleepy', |
| bash_command='sleep 5', |
| execution_timeout=timedelta(seconds=3), |
| retry_delay=timedelta(seconds=0), |
| dag=self.dag) |
| session = settings.Session() |
| try: |
| p.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| except Exception: |
| pass |
| try: |
| f.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) |
| except Exception: |
| pass |
| p_fails = session.query(TaskFail).filter_by( |
| task_id='pass_sleepy', |
| dag_id=self.dag.dag_id, |
| execution_date=DEFAULT_DATE).all() |
| f_fails = session.query(TaskFail).filter_by( |
| task_id='fail_sleepy', |
| dag_id=self.dag.dag_id, |
| execution_date=DEFAULT_DATE).all() |
| |
| self.assertEqual(0, len(p_fails)) |
| self.assertEqual(1, len(f_fails)) |
| self.assertGreaterEqual(sum([f.duration for f in f_fails]), 3) |
| |
| def test_run_command(self): |
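        # Run a `python -c` one-liner that prints a multi-byte character to
        # verify that run_command decodes subprocess output as UTF-8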
| write = r'sys.stdout.buffer.write("\u1000foo".encode("utf8"))' |
| |
| cmd = 'import sys; {0}; sys.stdout.flush()'.format(write) |
| |
| self.assertEqual(run_command("python -c '{0}'".format(cmd)), '\u1000foo') |
| |
| self.assertEqual(run_command('echo "foo bar"'), 'foo bar\n') |
| self.assertRaises(AirflowConfigException, run_command, 'bash -c "exit 1"') |
| |
| def test_externally_triggered_dagrun(self): |
| TI = TaskInstance |
| |
| # Create the dagrun between two "scheduled" execution dates of the DAG |
| EXECUTION_DATE = DEFAULT_DATE + timedelta(days=2) |
| EXECUTION_DS = EXECUTION_DATE.strftime('%Y-%m-%d') |
| EXECUTION_DS_NODASH = EXECUTION_DS.replace('-', '') |
| |
| dag = DAG( |
| TEST_DAG_ID, |
| default_args=self.args, |
| schedule_interval=timedelta(weeks=1), |
| start_date=DEFAULT_DATE) |
| task = DummyOperator(task_id='test_externally_triggered_dag_context', |
| dag=dag) |
| dag.create_dagrun(run_id=DagRun.id_for_date(EXECUTION_DATE), |
| execution_date=EXECUTION_DATE, |
| state=State.RUNNING, |
| external_trigger=True) |
| task.run( |
| start_date=EXECUTION_DATE, end_date=EXECUTION_DATE) |
| |
| ti = TI(task=task, execution_date=EXECUTION_DATE) |
| context = ti.get_template_context() |
| |
| # next_ds/prev_ds should be the execution date for manually triggered runs |
| self.assertEqual(context['next_ds'], EXECUTION_DS) |
| self.assertEqual(context['next_ds_nodash'], EXECUTION_DS_NODASH) |
| |
| self.assertEqual(context['prev_ds'], EXECUTION_DS) |
| self.assertEqual(context['prev_ds_nodash'], EXECUTION_DS_NODASH) |
| |
| |
| class TestCli(unittest.TestCase): |
| |
| TEST_USER1_EMAIL = 'test-user1@example.com' |
| TEST_USER2_EMAIL = 'test-user2@example.com' |
| |
| @classmethod |
| def setUpClass(cls): |
| super().setUpClass() |
| cls._cleanup() |
| |
| def setUp(self): |
| super().setUp() |
| from airflow.www import app as application |
| self.app, self.appbuilder = application.create_app(session=Session, testing=True) |
| self.app.config['TESTING'] = True |
| |
| self.parser = cli.CLIFactory.get_parser() |
| self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True) |
| settings.configure_orm() |
| self.session = Session |
| |
| def tearDown(self): |
| self._cleanup(session=self.session) |
| for email in [self.TEST_USER1_EMAIL, self.TEST_USER2_EMAIL]: |
| test_user = self.appbuilder.sm.find_user(email=email) |
| if test_user: |
| self.appbuilder.sm.del_register_user(test_user) |
| for role_name in ['FakeTeamA', 'FakeTeamB']: |
| if self.appbuilder.sm.find_role(role_name): |
| self.appbuilder.sm.delete_role(role_name) |
| |
| super().tearDown() |
| |
| @staticmethod |
| def _cleanup(session=None): |
| if session is None: |
| session = Session() |
| |
| session.query(Pool).delete() |
| session.query(Variable).delete() |
| session.commit() |
| session.close() |
| |
| |
| class TestConnection(unittest.TestCase): |
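    # Connections can be supplied as URIs through AIRFLOW_CONN_<CONN_ID>
    # environment variables; the tests below patch os.environ to exercise
    # that lookup path.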
| def setUp(self): |
| utils.db.initdb() |
| |
| @mock.patch.dict('os.environ', { |
| 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', |
| }) |
| def test_using_env_var(self): |
| c = SqliteHook.get_connection(conn_id='test_uri') |
| self.assertEqual('ec2.compute.com', c.host) |
| self.assertEqual('the_database', c.schema) |
| self.assertEqual('username', c.login) |
| self.assertEqual('password', c.password) |
| self.assertEqual(5432, c.port) |
| |
| @mock.patch.dict('os.environ', { |
| 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', |
| }) |
| def test_using_unix_socket_env_var(self): |
| c = SqliteHook.get_connection(conn_id='test_uri_no_creds') |
| self.assertEqual('ec2.compute.com', c.host) |
| self.assertEqual('the_database', c.schema) |
| self.assertIsNone(c.login) |
| self.assertIsNone(c.password) |
| self.assertIsNone(c.port) |
| |
| def test_param_setup(self): |
| c = Connection(conn_id='local_mysql', conn_type='mysql', |
| host='localhost', login='airflow', |
| password='airflow', schema='airflow') |
| self.assertEqual('localhost', c.host) |
| self.assertEqual('airflow', c.schema) |
| self.assertEqual('airflow', c.login) |
| self.assertEqual('airflow', c.password) |
| self.assertIsNone(c.port) |
| |
| def test_env_var_priority(self): |
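        # A matching AIRFLOW_CONN_* environment variable should take
        # precedence over the connection with the same id stored in the DB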
| c = SqliteHook.get_connection(conn_id='airflow_db') |
| self.assertNotEqual('ec2.compute.com', c.host) |
| |
| with mock.patch.dict('os.environ', { |
| 'AIRFLOW_CONN_AIRFLOW_DB': 'postgres://username:password@ec2.compute.com:5432/the_database', |
| }): |
| c = SqliteHook.get_connection(conn_id='airflow_db') |
| self.assertEqual('ec2.compute.com', c.host) |
| self.assertEqual('the_database', c.schema) |
| self.assertEqual('username', c.login) |
| self.assertEqual('password', c.password) |
| self.assertEqual(5432, c.port) |
| |
| @mock.patch.dict('os.environ', { |
| 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', |
| 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', |
| }) |
| def test_dbapi_get_uri(self): |
| conn = BaseHook.get_connection(conn_id='test_uri') |
| hook = conn.get_hook() |
| self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', hook.get_uri()) |
| conn2 = BaseHook.get_connection(conn_id='test_uri_no_creds') |
| hook2 = conn2.get_hook() |
| self.assertEqual('postgres://ec2.compute.com/the_database', hook2.get_uri()) |
| |
| @mock.patch.dict('os.environ', { |
| 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', |
| 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', |
| }) |
| def test_dbapi_get_sqlalchemy_engine(self): |
| conn = BaseHook.get_connection(conn_id='test_uri') |
| hook = conn.get_hook() |
| engine = hook.get_sqlalchemy_engine() |
| self.assertIsInstance(engine, sqlalchemy.engine.Engine) |
| self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', str(engine.url)) |
| |
| @mock.patch.dict('os.environ', { |
| 'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database', |
| 'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database', |
| }) |
| def test_get_connections_env_var(self): |
| conns = SqliteHook.get_connections(conn_id='test_uri') |
| assert len(conns) == 1 |
| assert conns[0].host == 'ec2.compute.com' |
| assert conns[0].schema == 'the_database' |
| assert conns[0].login == 'username' |
| assert conns[0].password == 'password' |
| assert conns[0].port == 5432 |
| |
| |
| if __name__ == '__main__': |
| unittest.main() |