| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import copy |
| import logging |
| import sys |
| import unittest.mock |
| from collections import namedtuple |
| from datetime import date, datetime, timedelta |
| from subprocess import CalledProcessError |
| from typing import Dict, List, Tuple |
| |
| import funcsigs |
| import pytest |
| |
| from airflow.exceptions import AirflowException |
| from airflow.models import DAG, DagRun, TaskInstance as TI |
| from airflow.models.baseoperator import BaseOperator |
| from airflow.models.taskinstance import clear_task_instances, set_current_context |
| from airflow.models.xcom_arg import XComArg |
| from airflow.operators.dummy import DummyOperator |
| from airflow.operators.python import ( |
| BranchPythonOperator, |
| PythonOperator, |
| PythonVirtualenvOperator, |
| ShortCircuitOperator, |
| get_current_context, |
| task as task_decorator, |
| ) |
| from airflow.utils import timezone |
| from airflow.utils.dates import days_ago |
| from airflow.utils.session import create_session |
| from airflow.utils.state import State |
| from airflow.utils.task_group import TaskGroup |
| from airflow.utils.types import DagRunType |
| from tests.test_utils.db import clear_db_runs |
| |
# Fixed dates used as execution/start dates so test runs are deterministic.
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
END_DATE = timezone.datetime(2016, 1, 2)
INTERVAL = timedelta(hours=12)  # schedule interval for DAGs that need one
FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1)

# Environment variables Airflow exports about the current task instance
# (AIRFLOW_CTX_*) while a task runs.
TI_CONTEXT_ENV_VARS = [
    'AIRFLOW_CTX_DAG_ID',
    'AIRFLOW_CTX_TASK_ID',
    'AIRFLOW_CTX_EXECUTION_DATE',
    'AIRFLOW_CTX_DAG_RUN_ID',
]
| |
| |
class Call:
    """Lightweight record of one invocation: positional and keyword arguments."""

    def __init__(self, *args, **kwargs):
        # Keep both argument groups verbatim for later comparison in tests.
        self.args, self.kwargs = args, kwargs
| |
| |
def build_recording_function(calls_collection):
    """
    Return a callable that appends a ``Call`` record of every invocation
    to *calls_collection*.

    A ``Mock`` instance cannot be used as a PythonOperator callable (some
    tests fail with ``TypeError: Object of type Mock is not JSON
    serializable``), so tests use this recorder and inspect the collected
    ``Call`` objects instead of ``Mock.assert_called_with``.
    """

    def recorder(*args, **kwargs):
        calls_collection.append(Call(*args, **kwargs))

    return recorder
| |
| |
class TestPythonBase(unittest.TestCase):
    """Base test class for TestPythonOperator and TestPythonSensor classes"""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # Start from a clean slate: drop any dag runs / task instances left
        # behind by other test modules sharing the same database.
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        super().setUp()
        # Fresh DAG per test; cleanups run in reverse registration order.
        self.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
        self.addCleanup(self.dag.clear)
        self.clear_run()
        self.addCleanup(self.clear_run)

    def tearDown(self):
        super().tearDown()

        # Remove rows created by this test so later tests see a clean DB.
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def clear_run(self):
        # Flag flipped by test callables to prove they were actually invoked.
        self.run = False

    def _assert_calls_equal(self, first, second):
        """Assert two Call records agree on args and on the test-relevant kwargs."""
        self.assertIsInstance(first, Call)
        self.assertIsInstance(second, Call)
        self.assertTupleEqual(first.args, second.args)
        # eliminate context (conf, dag_run, task_instance, etc.)
        test_args = ["an_int", "a_date", "a_templated_string"]
        first.kwargs = {key: value for (key, value) in first.kwargs.items() if key in test_args}
        second.kwargs = {key: value for (key, value) in second.kwargs.items() if key in test_args}
        self.assertDictEqual(first.kwargs, second.kwargs)
| |
| |
class TestPythonOperator(TestPythonBase):
    """Tests for PythonOperator: callable invocation, templating, and context injection."""

    def do_run(self):
        self.run = True

    def is_run(self):
        return self.run

    def test_python_operator_run(self):
        """Tests that the python callable is invoked on task run."""
        task = PythonOperator(python_callable=self.do_run, task_id='python_operator', dag=self.dag)
        self.assertFalse(self.is_run())
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        self.assertTrue(self.is_run())

    def test_python_operator_python_callable_is_callable(self):
        """Tests that PythonOperator will only instantiate if
        the python_callable argument is callable."""
        not_callable = {}
        with self.assertRaises(AirflowException):
            PythonOperator(python_callable=not_callable, task_id='python_operator', dag=self.dag)
        not_callable = None
        with self.assertRaises(AirflowException):
            PythonOperator(python_callable=not_callable, task_id='python_operator', dag=self.dag)

    def test_python_callable_arguments_are_templatized(self):
        """Test PythonOperator op_args are templatized"""
        recorded_calls = []

        # Create a named tuple and ensure it is still preserved
        # after the rendering is done
        Named = namedtuple('Named', ['var1', 'var2'])
        named_tuple = Named('{{ ds }}', 'unchanged')

        task = PythonOperator(
            task_id='python_operator',
            # a Mock instance cannot be used as a callable function or test fails with a
            # TypeError: Object of type Mock is not JSON serializable
            python_callable=build_recording_function(recorded_calls),
            op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple],
            dag=self.dag,
        )

        self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING,
        )
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        ds_templated = DEFAULT_DATE.date().isoformat()
        self.assertEqual(1, len(recorded_calls))
        self._assert_calls_equal(
            recorded_calls[0],
            Call(
                4,
                date(2019, 1, 1),
                f"dag {self.dag.dag_id} ran on {ds_templated}.",
                Named(ds_templated, 'unchanged'),
            ),
        )

    def test_python_callable_keyword_arguments_are_templatized(self):
        """Test PythonOperator op_kwargs are templatized"""
        recorded_calls = []

        task = PythonOperator(
            task_id='python_operator',
            # a Mock instance cannot be used as a callable function or test fails with a
            # TypeError: Object of type Mock is not JSON serializable
            python_callable=build_recording_function(recorded_calls),
            op_kwargs={
                'an_int': 4,
                'a_date': date(2019, 1, 1),
                'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}.",
            },
            dag=self.dag,
        )

        self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING,
        )
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self.assertEqual(1, len(recorded_calls))
        self._assert_calls_equal(
            recorded_calls[0],
            Call(
                an_int=4,
                a_date=date(2019, 1, 1),
                a_templated_string="dag {} ran on {}.".format(
                    self.dag.dag_id, DEFAULT_DATE.date().isoformat()
                ),
            ),
        )

    def test_python_operator_shallow_copy_attr(self):
        """deepcopy of the operator must keep op_kwargs values and python_callable shared."""
        not_callable = lambda x: x
        original_task = PythonOperator(
            python_callable=not_callable,
            task_id='python_operator',
            op_kwargs={'certain_attrs': ''},
            dag=self.dag,
        )
        new_task = copy.deepcopy(original_task)
        # shallow copy op_kwargs
        self.assertEqual(
            id(original_task.op_kwargs['certain_attrs']), id(new_task.op_kwargs['certain_attrs'])
        )
        # shallow copy python_callable
        self.assertEqual(id(original_task.python_callable), id(new_task.python_callable))

    def test_conflicting_kwargs(self):
        """A callable parameter that collides with a context key (``dag``) must raise ValueError."""
        self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING,
            external_trigger=False,
        )

        # dag is not allowed since it is a reserved keyword
        def func(dag):
            # An ValueError should be triggered since we're using dag as a
            # reserved keyword
            raise RuntimeError(f"Should not be triggered, dag: {dag}")

        python_operator = PythonOperator(
            task_id='python_operator', op_args=[1], python_callable=func, dag=self.dag
        )

        with self.assertRaises(ValueError) as context:
            python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        # Test membership against the exception *message*: checking
        # ``'dag' in context.exception`` would raise TypeError because
        # exception objects are not iterable.
        self.assertTrue('dag' in str(context.exception), "'dag' not found in the exception")

    def test_provide_context_does_not_fail(self):
        """
        ensures that provide_context doesn't break dags in 2.0
        """
        self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING,
            external_trigger=False,
        )

        def func(custom, dag):
            self.assertEqual(1, custom, "custom should be 1")
            self.assertIsNotNone(dag, "dag should be set")

        python_operator = PythonOperator(
            task_id='python_operator',
            op_kwargs={'custom': 1},
            python_callable=func,
            provide_context=True,
            dag=self.dag,
        )
        python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    def test_context_with_conflicting_op_args(self):
        """Context keys requested by the callable coexist with user op_kwargs."""
        self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING,
            external_trigger=False,
        )

        def func(custom, dag):
            self.assertEqual(1, custom, "custom should be 1")
            self.assertIsNotNone(dag, "dag should be set")

        python_operator = PythonOperator(
            task_id='python_operator', op_kwargs={'custom': 1}, python_callable=func, dag=self.dag
        )
        python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    def test_context_with_kwargs(self):
        """A ``**context`` catch-all parameter receives the full task context."""
        self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING,
            external_trigger=False,
        )

        def func(**context):
            # check if context is being set
            self.assertGreater(len(context), 0, "Context has not been injected")

        python_operator = PythonOperator(
            task_id='python_operator', op_kwargs={'custom': 1}, python_callable=func, dag=self.dag
        )
        python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
| |
| |
class TestAirflowTaskDecorator(TestPythonBase):
    """Tests for the ``@task`` decorator: output inference, task-id generation, XCom behavior."""

    def test_python_operator_python_callable_is_callable(self):
        """Tests that @task will only instantiate if
        the python_callable argument is callable."""
        not_callable = {}
        with pytest.raises(AirflowException):
            task_decorator(not_callable, dag=self.dag)

    def test_infer_multiple_outputs_using_typing(self):
        """multiple_outputs is inferred from the return annotation: only Dict enables it."""

        @task_decorator
        def identity_dict(x: int, y: int) -> Dict[str, int]:
            return {"x": x, "y": y}

        assert identity_dict(5, 5).operator.multiple_outputs is True  # pylint: disable=maybe-no-member

        @task_decorator
        def identity_tuple(x: int, y: int) -> Tuple[int, int]:
            return x, y

        assert identity_tuple(5, 5).operator.multiple_outputs is False  # pylint: disable=maybe-no-member

        @task_decorator
        def identity_int(x: int) -> int:
            return x

        assert identity_int(5).operator.multiple_outputs is False  # pylint: disable=maybe-no-member

        @task_decorator
        def identity_notyping(x: int):
            return x

        assert identity_notyping(5).operator.multiple_outputs is False  # pylint: disable=maybe-no-member

    def test_manual_multiple_outputs_false_with_typings(self):
        """An explicit multiple_outputs=False overrides the Dict return annotation."""

        @task_decorator(multiple_outputs=False)
        def identity2(x: int, y: int) -> Dict[int, int]:
            return (x, y)

        with self.dag:
            res = identity2(8, 4)

        dr = self.dag.create_dagrun(
            run_id=DagRunType.MANUAL.value,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        res.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)  # pylint: disable=maybe-no-member

        ti = dr.get_task_instances()[0]

        assert res.operator.multiple_outputs is False  # pylint: disable=maybe-no-member
        assert ti.xcom_pull() == [8, 4]  # pylint: disable=maybe-no-member
        assert ti.xcom_pull(key="return_value_0") is None
        assert ti.xcom_pull(key="return_value_1") is None

    def test_multiple_outputs_ignore_typing(self):
        """A Tuple return annotation does not enable multiple outputs."""

        @task_decorator
        def identity_tuple(x: int, y: int) -> Tuple[int, int]:
            return x, y

        with self.dag:
            ident = identity_tuple(35, 36)

        dr = self.dag.create_dagrun(
            run_id=DagRunType.MANUAL.value,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        ident.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)  # pylint: disable=maybe-no-member

        ti = dr.get_task_instances()[0]

        assert not ident.operator.multiple_outputs  # pylint: disable=maybe-no-member
        assert ti.xcom_pull() == [35, 36]
        assert ti.xcom_pull(key="return_value_0") is None
        assert ti.xcom_pull(key="return_value_1") is None

    def test_fails_bad_signature(self):
        """Tests that @task will fail if signature is not binding."""

        @task_decorator
        def add_number(num: int) -> int:
            return num + 2

        with pytest.raises(TypeError):
            add_number(2, 3)  # pylint: disable=too-many-function-args
        with pytest.raises(TypeError):
            add_number()  # pylint: disable=no-value-for-parameter
        add_number('test')  # pylint: disable=no-value-for-parameter

    def test_fail_method(self):
        """Tests that @task will fail if used to decorate an instance method."""

        with pytest.raises(AirflowException):

            class Test:
                num = 2

                @task_decorator
                def add_number(self, num: int) -> int:
                    return self.num + num

            Test().add_number(2)

    def test_fail_multiple_outputs_key_type(self):
        """multiple_outputs=True requires string keys; an int key fails at run time."""

        @task_decorator(multiple_outputs=True)
        def add_number(num: int):
            return {2: num}

        with self.dag:
            ret = add_number(2)
        self.dag.create_dagrun(
            run_id=DagRunType.MANUAL,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        with pytest.raises(AirflowException):
            # pylint: disable=maybe-no-member
            ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    def test_fail_multiple_outputs_no_dict(self):
        """multiple_outputs=True requires a dict return value; a scalar fails at run time."""

        @task_decorator(multiple_outputs=True)
        def add_number(num: int):
            return num

        with self.dag:
            ret = add_number(2)
        self.dag.create_dagrun(
            run_id=DagRunType.MANUAL,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        with pytest.raises(AirflowException):
            # pylint: disable=maybe-no-member
            ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    def test_python_callable_arguments_are_templatized(self):
        """Test @task op_args are templatized"""
        recorded_calls = []

        # Create a named tuple and ensure it is still preserved
        # after the rendering is done
        Named = namedtuple('Named', ['var1', 'var2'])
        named_tuple = Named('{{ ds }}', 'unchanged')

        task = task_decorator(
            # a Mock instance cannot be used as a callable function or test fails with a
            # TypeError: Object of type Mock is not JSON serializable
            build_recording_function(recorded_calls),
            dag=self.dag,
        )
        ret = task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple)

        self.dag.create_dagrun(
            run_id=DagRunType.MANUAL,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING,
        )
        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)  # pylint: disable=maybe-no-member

        ds_templated = DEFAULT_DATE.date().isoformat()
        assert len(recorded_calls) == 1
        self._assert_calls_equal(
            recorded_calls[0],
            Call(
                4,
                date(2019, 1, 1),
                f"dag {self.dag.dag_id} ran on {ds_templated}.",
                Named(ds_templated, 'unchanged'),
            ),
        )

    def test_python_callable_keyword_arguments_are_templatized(self):
        """Test PythonOperator op_kwargs are templatized"""
        recorded_calls = []

        task = task_decorator(
            # a Mock instance cannot be used as a callable function or test fails with a
            # TypeError: Object of type Mock is not JSON serializable
            build_recording_function(recorded_calls),
            dag=self.dag,
        )
        ret = task(an_int=4, a_date=date(2019, 1, 1), a_templated_string="dag {{dag.dag_id}} ran on {{ds}}.")
        self.dag.create_dagrun(
            run_id=DagRunType.MANUAL,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING,
        )
        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)  # pylint: disable=maybe-no-member

        assert len(recorded_calls) == 1
        self._assert_calls_equal(
            recorded_calls[0],
            Call(
                an_int=4,
                a_date=date(2019, 1, 1),
                a_templated_string="dag {} ran on {}.".format(
                    self.dag.dag_id, DEFAULT_DATE.date().isoformat()
                ),
            ),
        )

    def test_manual_task_id(self):
        """Test manually setting task_id"""

        @task_decorator(task_id='some_name')
        def do_run():
            return 4

        with self.dag:
            do_run()
            assert ['some_name'] == self.dag.task_ids

    def test_multiple_calls(self):
        """Test calling task multiple times in a DAG"""

        @task_decorator
        def do_run():
            return 4

        with self.dag:
            do_run()
            assert ['do_run'] == self.dag.task_ids
            do_run_1 = do_run()
            do_run_2 = do_run()
            assert ['do_run', 'do_run__1', 'do_run__2'] == self.dag.task_ids

        assert do_run_1.operator.task_id == 'do_run__1'  # pylint: disable=maybe-no-member
        assert do_run_2.operator.task_id == 'do_run__2'  # pylint: disable=maybe-no-member

    def test_multiple_calls_in_task_group(self):
        """Test calling task multiple times in a TaskGroup"""

        @task_decorator
        def do_run():
            return 4

        group_id = "KnightsOfNii"
        with self.dag:
            with TaskGroup(group_id=group_id):
                do_run()
                assert [f"{group_id}.do_run"] == self.dag.task_ids
                do_run()
                assert [f"{group_id}.do_run", f"{group_id}.do_run__1"] == self.dag.task_ids

        assert len(self.dag.task_ids) == 2

    def test_call_20(self):
        """Test calling decorated function 21 times in a DAG"""

        @task_decorator
        def __do_run():
            return 4

        with self.dag:
            __do_run()
            for _ in range(20):
                __do_run()

        assert self.dag.task_ids[-1] == '__do_run__20'

    def test_multiple_outputs(self):
        """Tests pushing multiple outputs as a dictionary"""

        @task_decorator(multiple_outputs=True)
        def return_dict(number: int):
            return {'number': number + 1, '43': 43}

        test_number = 10
        with self.dag:
            ret = return_dict(test_number)

        dr = self.dag.create_dagrun(
            run_id=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)  # pylint: disable=maybe-no-member

        ti = dr.get_task_instances()[0]
        assert ti.xcom_pull(key='number') == test_number + 1
        assert ti.xcom_pull(key='43') == 43
        assert ti.xcom_pull() == {'number': test_number + 1, '43': 43}

    def test_default_args(self):
        """Test that default_args are captured when calling the function correctly"""

        @task_decorator
        def do_run():
            return 4

        with self.dag:
            ret = do_run()
        assert ret.operator.owner == 'airflow'  # pylint: disable=maybe-no-member

    def test_xcom_arg(self):
        """Tests that returned key in XComArg is returned correctly"""

        @task_decorator
        def add_2(number: int):
            return number + 2

        @task_decorator
        def add_num(number: int, num2: int = 2):
            return number + num2

        test_number = 10

        with self.dag:
            bigger_number = add_2(test_number)
            ret = add_num(bigger_number, XComArg(bigger_number.operator))  # pylint: disable=maybe-no-member

        dr = self.dag.create_dagrun(
            run_id=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        bigger_number.operator.run(  # pylint: disable=maybe-no-member
            start_date=DEFAULT_DATE, end_date=DEFAULT_DATE
        )
        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)  # pylint: disable=maybe-no-member
        ti_add_num = [ti for ti in dr.get_task_instances() if ti.task_id == 'add_num'][0]
        assert ti_add_num.xcom_pull(key=ret.key) == (test_number + 2) * 2  # pylint: disable=maybe-no-member

    def test_dag_task(self):
        """Tests dag.task property to generate task"""

        @self.dag.task
        def add_2(number: int):
            return number + 2

        test_number = 10
        res = add_2(test_number)
        add_2(res)

        assert 'add_2' in self.dag.task_ids

    def test_dag_task_multiple_outputs(self):
        """Tests dag.task property to generate task with multiple outputs"""

        @self.dag.task(multiple_outputs=True)
        def add_2(number: int):
            return {'1': number + 2, '2': 42}

        test_number = 10
        add_2(test_number)
        add_2(test_number)

        assert 'add_2' in self.dag.task_ids

    def test_airflow_task(self):
        """Tests airflow.task decorator to generate task"""
        from airflow.decorators import task

        @task
        def add_2(number: int):
            return number + 2

        test_number = 10
        with self.dag:
            add_2(test_number)

        assert 'add_2' in self.dag.task_ids

    def test_task_documentation(self):
        """Tests that task_decorator loads doc_md from function doc"""

        @task_decorator
        def add_2(number: int):
            """
            Adds 2 to number.
            """
            return number + 2

        test_number = 10
        with self.dag:
            ret = add_2(test_number)

        # Compare with ``==``: the original ``assert x, "msg"`` form treated
        # the expected string as the assert *message* and only checked that
        # doc_md was non-empty.
        assert ret.operator.doc_md.strip() == "Adds 2 to number."  # pylint: disable=maybe-no-member
| |
| |
class TestBranchOperator(unittest.TestCase):
    """Tests BranchPythonOperator: followed branches run, others are skipped."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # Start with a clean database shared across tests.
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        self.dag = DAG(
            'branch_operator_test',
            default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
            schedule_interval=INTERVAL,
        )

        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
        # branch_3 is created only by tests that need a third branch.
        self.branch_3 = None

    def tearDown(self):
        super().tearDown()

        # Remove rows created by this test so later tests see a clean DB.
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def test_without_dag_run(self):
        """This checks the defensive against non existent tasks in a dag run"""
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE)

            for ti in tis:
                if ti.task_id == 'make_choice':
                    self.assertEqual(ti.state, State.SUCCESS)
                elif ti.task_id == 'branch_1':
                    # should exist with state None
                    self.assertEqual(ti.state, State.NONE)
                elif ti.task_id == 'branch_2':
                    self.assertEqual(ti.state, State.SKIPPED)
                else:
                    raise ValueError(f'Invalid task id {ti.task_id} found!')

    def test_branch_list_without_dag_run(self):
        """This checks if the BranchPythonOperator supports branching off to a list of tasks."""
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: ['branch_1', 'branch_2']
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        self.dag.clear()

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE)

            # Both listed branches are left runnable; the unlisted one skips.
            expected = {
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.NONE,
                "branch_3": State.SKIPPED,
            }

            for ti in tis:
                if ti.task_id in expected:
                    self.assertEqual(ti.state, expected[ti.task_id])
                else:
                    raise ValueError(f'Invalid task id {ti.task_id} found!')

    def test_with_dag_run(self):
        """The followed branch stays runnable, the other is skipped, within a real dag run."""
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                self.assertEqual(ti.state, State.NONE)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.SKIPPED)
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

    def test_with_skip_in_branch_downstream_dependencies(self):
        """A task downstream of both the branch op and the chosen branch is not skipped."""
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
        )

        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                self.assertEqual(ti.state, State.NONE)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.NONE)
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

    def test_with_skip_in_branch_downstream_dependencies2(self):
        """When the other branch is chosen, the shared downstream task stays runnable."""
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_2'
        )

        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                self.assertEqual(ti.state, State.SKIPPED)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.NONE)
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

    def test_xcom_push(self):
        """The branch operator pushes its chosen branch to XCom."""
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.xcom_pull(task_ids='make_choice'), 'branch_1')

    def test_clear_skipped_downstream_task(self):
        """
        After a downstream task is skipped by BranchPythonOperator, clearing the skipped task
        should not cause it to be executed.
        """
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
        )
        branches = [self.branch_1, self.branch_2]
        branch_op >> branches
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        for task in branches:
            task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.SKIPPED)
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

        children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()]

        # Clear the children tasks.
        with create_session() as session:
            clear_task_instances(children_tis, session=session, dag=self.dag)

        # Run the cleared tasks again.
        for task in branches:
            task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        # Check if the states are correct after children tasks are cleared.
        for ti in dr.get_task_instances():
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.SKIPPED)
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
| |
| |
| class TestShortCircuitOperator(unittest.TestCase): |
| @classmethod |
| def setUpClass(cls): |
| super().setUpClass() |
| |
| with create_session() as session: |
| session.query(DagRun).delete() |
| session.query(TI).delete() |
| |
| def tearDown(self): |
| super().tearDown() |
| |
| with create_session() as session: |
| session.query(DagRun).delete() |
| session.query(TI).delete() |
| |
    def test_without_dag_run(self):
        """This checks the defensive against non existent tasks in a dag run"""
        value = False
        dag = DAG(
            'shortcircuit_operator_test_without_dag_run',
            default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
            schedule_interval=INTERVAL,
        )
        # The operator short-circuits when the callable returns a falsy value.
        short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value)
        branch_1 = DummyOperator(task_id='branch_1', dag=dag)
        branch_1.set_upstream(short_op)
        branch_2 = DummyOperator(task_id='branch_2', dag=dag)
        branch_2.set_upstream(branch_1)
        upstream = DummyOperator(task_id='upstream', dag=dag)
        upstream.set_downstream(short_op)
        dag.clear()

        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date == DEFAULT_DATE)

            # Falsy condition: every downstream task is skipped.
            for ti in tis:
                if ti.task_id == 'make_choice':
                    self.assertEqual(ti.state, State.SUCCESS)
                elif ti.task_id == 'upstream':
                    # should not exist
                    raise ValueError(f'Invalid task id {ti.task_id} found!')
                elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                    self.assertEqual(ti.state, State.SKIPPED)
                else:
                    raise ValueError(f'Invalid task id {ti.task_id} found!')

            # Flip the condition (the lambda closes over ``value``) and rerun.
            value = True
            dag.clear()

            short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
            # Truthy condition: downstream tasks are left runnable (state None).
            for ti in tis:
                if ti.task_id == 'make_choice':
                    self.assertEqual(ti.state, State.SUCCESS)
                elif ti.task_id == 'upstream':
                    # should not exist
                    raise ValueError(f'Invalid task id {ti.task_id} found!')
                elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                    self.assertEqual(ti.state, State.NONE)
                else:
                    raise ValueError(f'Invalid task id {ti.task_id} found!')
| |
| def test_with_dag_run(self): |
| value = False |
| dag = DAG( |
| 'shortcircuit_operator_test_with_dag_run', |
| default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE}, |
| schedule_interval=INTERVAL, |
| ) |
| short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value) |
| branch_1 = DummyOperator(task_id='branch_1', dag=dag) |
| branch_1.set_upstream(short_op) |
| branch_2 = DummyOperator(task_id='branch_2', dag=dag) |
| branch_2.set_upstream(branch_1) |
| upstream = DummyOperator(task_id='upstream', dag=dag) |
| upstream.set_downstream(short_op) |
| dag.clear() |
| |
| logging.error("Tasks %s", dag.tasks) |
| dr = dag.create_dagrun( |
| run_type=DagRunType.MANUAL, |
| start_date=timezone.utcnow(), |
| execution_date=DEFAULT_DATE, |
| state=State.RUNNING, |
| ) |
| |
| upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) |
| short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) |
| |
| tis = dr.get_task_instances() |
| self.assertEqual(len(tis), 4) |
| for ti in tis: |
| if ti.task_id == 'make_choice': |
| self.assertEqual(ti.state, State.SUCCESS) |
| elif ti.task_id == 'upstream': |
| self.assertEqual(ti.state, State.SUCCESS) |
| elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2': |
| self.assertEqual(ti.state, State.SKIPPED) |
| else: |
| raise ValueError(f'Invalid task id {ti.task_id} found!') |
| |
| value = True |
| dag.clear() |
| dr.verify_integrity() |
| upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) |
| short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) |
| |
| tis = dr.get_task_instances() |
| self.assertEqual(len(tis), 4) |
| for ti in tis: |
| if ti.task_id == 'make_choice': |
| self.assertEqual(ti.state, State.SUCCESS) |
| elif ti.task_id == 'upstream': |
| self.assertEqual(ti.state, State.SUCCESS) |
| elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2': |
| self.assertEqual(ti.state, State.NONE) |
| else: |
| raise ValueError(f'Invalid task id {ti.task_id} found!') |
| |
| def test_clear_skipped_downstream_task(self): |
| """ |
| After a downstream task is skipped by ShortCircuitOperator, clearing the skipped task |
| should not cause it to be executed. |
| """ |
| dag = DAG( |
| 'shortcircuit_clear_skipped_downstream_task', |
| default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE}, |
| schedule_interval=INTERVAL, |
| ) |
| short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: False) |
| downstream = DummyOperator(task_id='downstream', dag=dag) |
| |
| short_op >> downstream |
| |
| dag.clear() |
| |
| dr = dag.create_dagrun( |
| run_type=DagRunType.MANUAL, |
| start_date=timezone.utcnow(), |
| execution_date=DEFAULT_DATE, |
| state=State.RUNNING, |
| ) |
| |
| short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) |
| downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) |
| |
| tis = dr.get_task_instances() |
| |
| for ti in tis: |
| if ti.task_id == 'make_choice': |
| self.assertEqual(ti.state, State.SUCCESS) |
| elif ti.task_id == 'downstream': |
| self.assertEqual(ti.state, State.SKIPPED) |
| else: |
| raise ValueError(f'Invalid task id {ti.task_id} found!') |
| |
| # Clear downstream |
| with create_session() as session: |
| clear_task_instances([t for t in tis if t.task_id == "downstream"], session=session, dag=dag) |
| |
| # Run downstream again |
| downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) |
| |
| # Check if the states are correct. |
| for ti in dr.get_task_instances(): |
| if ti.task_id == 'make_choice': |
| self.assertEqual(ti.state, State.SUCCESS) |
| elif ti.task_id == 'downstream': |
| self.assertEqual(ti.state, State.SKIPPED) |
| else: |
| raise ValueError(f'Invalid task id {ti.task_id} found!') |
| |
| |
| virtualenv_string_args: List[str] = [] |
| |
| |
class TestPythonVirtualenvOperator(unittest.TestCase):
    """Tests for PythonVirtualenvOperator.

    Each test wraps a callable in the operator, which serializes it and
    executes it inside a freshly created virtualenv subprocess.
    """

    def setUp(self):
        super().setUp()
        self.dag = DAG(
            'test_dag',
            default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
            schedule_interval=INTERVAL,
        )
        self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )
        self.addCleanup(self.dag.clear)

    def tearDown(self):
        super().tearDown()
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def _run_as_operator(self, fn, python_version=sys.version_info[0], **kwargs):
        """Wrap ``fn`` in a PythonVirtualenvOperator and run it for DEFAULT_DATE.

        :param fn: the python callable to execute inside the virtualenv
        :param python_version: target interpreter major (or ``major.minor``)
            version; defaults to the current interpreter's major version
        :return: the executed operator, for further inspection
        """
        task = PythonVirtualenvOperator(
            python_callable=fn, python_version=python_version, task_id='task', dag=self.dag, **kwargs
        )
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        return task

    def test_add_dill(self):
        """use_dill=True must add dill to the virtualenv requirements."""

        def f():
            pass

        task = self._run_as_operator(f, use_dill=True, system_site_packages=False)
        assert 'dill' in task.requirements

    def test_no_requirements(self):
        """Tests that the python callable is invoked on task run."""

        def f():
            pass

        self._run_as_operator(f)

    def test_no_system_site_packages(self):
        """Without system site packages, host-only modules must not be importable."""

        def f():
            try:
                import funcsigs  # noqa: F401 # pylint: disable=redefined-outer-name,reimported,unused-import
            except ImportError:
                return True
            raise Exception

        self._run_as_operator(f, system_site_packages=False, requirements=['dill'])

    def test_system_site_packages(self):
        """With system site packages, host modules are importable in the venv."""

        def f():
            import funcsigs  # noqa: F401 # pylint: disable=redefined-outer-name,reimported,unused-import

        self._run_as_operator(f, requirements=['funcsigs'], system_site_packages=True)

    def test_with_requirements_pinned(self):
        """A pinned requirement installs that exact version in the venv."""
        self.assertNotEqual('0.4', funcsigs.__version__, 'Please update this string if this fails')

        def f():
            import funcsigs  # noqa: F401 # pylint: disable=redefined-outer-name,reimported

            if funcsigs.__version__ != '0.4':
                raise Exception

        self._run_as_operator(f, requirements=['funcsigs==0.4'])

    def test_unpinned_requirements(self):
        """Unpinned requirements install successfully."""

        def f():
            import funcsigs  # noqa: F401 # pylint: disable=redefined-outer-name,reimported,unused-import

        self._run_as_operator(f, requirements=['funcsigs', 'dill'], system_site_packages=False)

    def test_range_requirements(self):
        """Version-range requirements install successfully."""

        def f():
            import funcsigs  # noqa: F401 # pylint: disable=redefined-outer-name,reimported,unused-import

        self._run_as_operator(f, requirements=['funcsigs>1.0', 'dill'], system_site_packages=False)

    def test_fail(self):
        """A callable that raises surfaces as a CalledProcessError from the venv."""

        def f():
            raise Exception

        with self.assertRaises(CalledProcessError):
            self._run_as_operator(f)

    def test_python_2(self):
        """python_version=2 runs the callable under a Python 2 interpreter."""

        def f():
            {}.iteritems()  # pylint: disable=no-member

        self._run_as_operator(f, python_version=2, requirements=['dill'])

    def test_python_2_7(self):
        """A major.minor version string is accepted for python_version."""

        def f():
            {}.iteritems()  # pylint: disable=no-member
            return True

        self._run_as_operator(f, python_version='2.7', requirements=['dill'])

    def test_python_3(self):
        """python_version=3 runs the callable under a Python 3 interpreter."""

        def f():
            import sys  # pylint: disable=reimported,unused-import,redefined-outer-name

            print(sys.version)
            try:
                {}.iteritems()  # pylint: disable=no-member
            except AttributeError:
                return
            raise Exception

        self._run_as_operator(f, python_version=3, use_dill=False, requirements=['dill'])

    @staticmethod
    def _invert_python_major_version():
        """Return the major version opposite the running interpreter (2 <-> 3)."""
        if sys.version_info[0] == 2:
            return 3
        else:
            return 2

    def test_wrong_python_op_args(self):
        """Passing op_args across a major-version boundary raises AirflowException."""
        # Reuse the helper instead of duplicating the version-inversion logic.
        version = self._invert_python_major_version()

        def f():
            pass

        with self.assertRaises(AirflowException):
            self._run_as_operator(f, python_version=version, op_args=[1])

    def test_without_dill(self):
        """op_args work with the default (pickle-based) serialization."""

        def f(a):
            return a

        self._run_as_operator(f, system_site_packages=False, use_dill=False, op_args=[4])

    def test_string_args(self):
        """string_args are exposed to the callable via the virtualenv_string_args global."""

        def f():
            global virtualenv_string_args  # pylint: disable=global-statement
            print(virtualenv_string_args)
            if virtualenv_string_args[0] != virtualenv_string_args[2]:
                raise Exception

        self._run_as_operator(f, python_version=self._invert_python_major_version(), string_args=[1, 2, 1])

    def test_with_args(self):
        """op_args and op_kwargs are forwarded to the callable."""

        def f(a, b, c=False, d=False):
            if a == 0 and b == 1 and c and not d:
                return True
            else:
                raise Exception

        self._run_as_operator(f, op_args=[0, 1], op_kwargs={'c': True})

    def test_return_none(self):
        """A callable returning None runs cleanly."""

        def f():
            return None

        self._run_as_operator(f)

    def test_lambda(self):
        """Lambdas cannot be serialized into a virtualenv and are rejected."""
        with self.assertRaises(AirflowException):
            PythonVirtualenvOperator(python_callable=lambda x: 4, task_id='task', dag=self.dag)

    def test_nonimported_as_arg(self):
        """Arguments of types not imported by the callable still deserialize."""

        def f(_):
            return None

        self._run_as_operator(f, op_args=[datetime.utcnow()])

    def test_context(self):
        """templates_dict values are templated before being passed in."""

        def f(templates_dict):
            return templates_dict['ds']

        self._run_as_operator(f, templates_dict={'ds': '{{ ds }}'})

    def test_airflow_context(self):
        """All airflow-specific context keys can be requested with dill + site packages."""

        def f(
            # basic
            ds_nodash,
            inlets,
            next_ds,
            next_ds_nodash,
            outlets,
            params,
            prev_ds,
            prev_ds_nodash,
            run_id,
            task_instance_key_str,
            test_mode,
            tomorrow_ds,
            tomorrow_ds_nodash,
            ts,
            ts_nodash,
            ts_nodash_with_tz,
            yesterday_ds,
            yesterday_ds_nodash,
            # pendulum-specific
            execution_date,
            next_execution_date,
            prev_execution_date,
            prev_execution_date_success,
            prev_start_date_success,
            # airflow-specific
            macros,
            conf,
            dag,
            dag_run,
            task,
            # other
            **context,
        ):  # pylint: disable=unused-argument,too-many-arguments,too-many-locals
            pass

        self._run_as_operator(f, use_dill=True, system_site_packages=True, requirements=None)

    def test_pendulum_context(self):
        """Pendulum-backed context keys work when pendulum is installed in the venv."""

        def f(
            # basic
            ds_nodash,
            inlets,
            next_ds,
            next_ds_nodash,
            outlets,
            params,
            prev_ds,
            prev_ds_nodash,
            run_id,
            task_instance_key_str,
            test_mode,
            tomorrow_ds,
            tomorrow_ds_nodash,
            ts,
            ts_nodash,
            ts_nodash_with_tz,
            yesterday_ds,
            yesterday_ds_nodash,
            # pendulum-specific
            execution_date,
            next_execution_date,
            prev_execution_date,
            prev_execution_date_success,
            prev_start_date_success,
            # other
            **context,
        ):  # pylint: disable=unused-argument,too-many-arguments,too-many-locals
            pass

        self._run_as_operator(
            f, use_dill=True, system_site_packages=False, requirements=['pendulum', 'lazy_object_proxy']
        )

    def test_base_context(self):
        """Basic (string-only) context keys work with no extra requirements."""

        def f(
            # basic
            ds_nodash,
            inlets,
            next_ds,
            next_ds_nodash,
            outlets,
            params,
            prev_ds,
            prev_ds_nodash,
            run_id,
            task_instance_key_str,
            test_mode,
            tomorrow_ds,
            tomorrow_ds_nodash,
            ts,
            ts_nodash,
            ts_nodash_with_tz,
            yesterday_ds,
            yesterday_ds_nodash,
            # other
            **context,
        ):  # pylint: disable=unused-argument,too-many-arguments,too-many-locals
            pass

        self._run_as_operator(f, use_dill=True, system_site_packages=False, requirements=None)
| |
| |
# Shared default_args for the DAGs built in the current-context tests below.
DEFAULT_ARGS = {
    "owner": "test",
    "depends_on_past": True,
    "start_date": days_ago(1),
    "end_date": datetime.today(),
    "schedule_interval": "@once",
    "retries": 1,
    "retry_delay": timedelta(minutes=1),
}
| |
| |
class TestCurrentContext:
    """Unit tests for the set_current_context / get_current_context pair."""

    def test_current_context_no_context_raise(self):
        """Outside of any context manager, fetching the context must fail."""
        with pytest.raises(AirflowException):
            get_current_context()

    def test_current_context_roundtrip(self):
        """A context set via set_current_context is returned unchanged."""
        example_context = {"Hello": "World"}

        with set_current_context(example_context):
            assert get_current_context() == example_context

    def test_context_removed_after_exit(self):
        """Once the context manager exits, the context is unavailable again."""
        example_context = {"Hello": "World"}

        with set_current_context(example_context):
            pass
        with pytest.raises(
            AirflowException,
        ):
            get_current_context()

    def test_nested_context(self):
        """
        Nested execution context should be supported in case the user uses multiple context managers.
        Each time the execute method of an operator is called, we set a new 'current' context.
        This test verifies that no matter how many contexts are entered - order is preserved
        """
        max_stack_depth = 15
        managers = []
        # Enter contexts in ascending order - like 15 nested ``with`` blocks.
        for depth in range(max_stack_depth):
            manager = set_current_context({"ContextId": depth})
            manager.__enter__()  # pylint: disable=E1101
            managers.append(manager)
        # Unwind in reverse order - the context stack is LIFO.
        for depth in reversed(range(max_stack_depth)):
            assert get_current_context()["ContextId"] == depth
            # Equivalent to leaving the ``with`` block.
            managers[depth].__exit__(None, None, None)
| |
| |
class MyContextAssertOperator(BaseOperator):
    """Operator whose execute() verifies the passed context matches get_current_context()."""

    def execute(self, context):
        # The context handed to execute() must be the globally exposed one.
        assert get_current_context() == context
| |
| |
def get_all_the_context(**context):
    """Assert that the kwargs-style context equals the globally accessible one."""
    assert get_current_context() == context
| |
| |
@pytest.fixture()
def clear_db():
    """Guarantee a clean DagRun table both before and after each test."""
    clear_db_runs()
    yield
    clear_db_runs()
| |
| |
@pytest.mark.usefixtures("clear_db")
class TestCurrentContextRuntime:
    """End-to-end checks that get_current_context works inside running tasks."""

    def test_context_in_task(self):
        """An operator's execute() sees the same context as get_current_context()."""
        with DAG(dag_id="assert_context_dag", default_args=DEFAULT_ARGS):
            operator = MyContextAssertOperator(task_id="assert_context")
            operator.run(ignore_first_depends_on_past=True, ignore_ti_state=True)

    def test_get_context_in_old_style_context_task(self):
        """A **kwargs-style PythonOperator callable also matches get_current_context()."""
        with DAG(dag_id="edge_case_context_dag", default_args=DEFAULT_ARGS):
            operator = PythonOperator(python_callable=get_all_the_context, task_id="get_all_the_context")
            operator.run(ignore_first_depends_on_past=True, ignore_ti_state=True)
| |
| |
@pytest.mark.parametrize(
    "choice,expected_states",
    [
        ("task1", [State.SUCCESS, State.SUCCESS, State.SUCCESS]),
        ("join", [State.SUCCESS, State.SKIPPED, State.SUCCESS]),
    ],
)
def test_empty_branch(choice, expected_states):
    """
    Tests that BranchPythonOperator handles empty branches properly.
    """
    with DAG(
        'test_empty_branch',
        start_date=DEFAULT_DATE,
    ) as dag:
        branch = BranchPythonOperator(task_id='branch', python_callable=lambda: choice)
        task1 = DummyOperator(task_id='task1')
        join = DummyOperator(task_id='join', trigger_rule="none_failed_or_skipped")

        branch >> [task1, join]
        task1 >> join

    dag.clear(start_date=DEFAULT_DATE)

    task_ids = ["branch", "task1", "join"]

    # Build one task instance per task, then run them in declaration order:
    # the branch first, then its downstream tasks.
    instances = {tid: TI(dag.get_task(tid), execution_date=DEFAULT_DATE) for tid in task_ids}
    for tid in task_ids:
        instances[tid].run()

    def final_state(instance):
        instance.refresh_from_db()
        return instance.state

    assert [final_state(instances[tid]) for tid in task_ids] == expected_states