| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import datetime |
| import unittest |
| |
| from airflow.models import DAG, DagRun, TaskInstance as TI |
| from airflow.operators.branch import BaseBranchOperator |
| from airflow.operators.dummy import DummyOperator |
| from airflow.utils import timezone |
| from airflow.utils.session import create_session |
| from airflow.utils.state import State |
| from airflow.utils.types import DagRunType |
| |
# Reference date used as the DAG start_date and as the execution date of
# every operator run / DagRun created in this module.
DEFAULT_DATE = timezone.datetime(year=2016, month=1, day=1)
# Schedule interval handed to the DAG that setUp builds.
INTERVAL = datetime.timedelta(hours=12)
| |
| |
class ChooseBranchOne(BaseBranchOperator):
    """Branch operator stub that always selects the single task ``branch_1``."""

    def choose_branch(self, context):
        """Return the task_id to follow; ``context`` is not consulted."""
        chosen_task_id = 'branch_1'
        return chosen_task_id
| |
| |
class ChooseBranchOneTwo(BaseBranchOperator):
    """Branch operator stub that always selects both ``branch_1`` and ``branch_2``."""

    def choose_branch(self, context):
        """Return the list of task_ids to follow; ``context`` is not consulted."""
        chosen_task_ids = ['branch_1', 'branch_2']
        return chosen_task_ids
| |
| |
class TestBranchOperator(unittest.TestCase):
    """Tests for ``BaseBranchOperator`` subclasses, with and without a DagRun.

    Each test wires a branching task (``make_choice``) upstream of the
    ``branch_*`` dummy tasks, runs the branching task for ``DEFAULT_DATE``
    and then checks the state recorded on every task instance it can find.
    """

    @staticmethod
    def _clear_db():
        """Delete all DagRun and TaskInstance rows so tests do not leak state."""
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def _assert_ti_states(self, tis, expected):
        """Check every task instance in ``tis`` against ``expected``.

        :param tis: iterable of task instances exposing ``task_id`` and ``state``
        :param expected: mapping of task_id to the state that task must have
        :raises Exception: if a task instance with an unlisted task_id is found
        :raises AssertionError: if a state differs from the expectation, or if
            no task instance for the branching task itself was found
        """
        seen = set()
        for ti in tis:
            if ti.task_id not in expected:
                raise Exception(f"Unexpected task instance for task_id {ti.task_id!r}")
            wanted = expected[ti.task_id]
            assert ti.state == wanted, f"{ti.task_id}: expected state {wanted!r}, got {ti.state!r}"
            seen.add(ti.task_id)

        # Guard against a vacuous pass: the branching task was run explicitly,
        # so its own task instance must be among the ones checked. The branch
        # tasks are deliberately not required here, because without a DagRun a
        # chosen branch may have no task instance row at all.
        branch_task_id = self.branch_op.task_id
        assert branch_task_id in seen, f"No task instance found for {branch_task_id!r}"

    @classmethod
    def setUpClass(cls):
        """Start the test class from an empty DagRun / TaskInstance table."""
        super().setUpClass()
        cls._clear_db()

    def setUp(self):
        """Build a fresh DAG with two dummy branch tasks for each test."""
        self.dag = DAG(
            'branch_operator_test',
            default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
            schedule_interval=INTERVAL,
        )

        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
        # Created by the individual tests that need them.
        self.branch_3 = None
        self.branch_op = None

    def tearDown(self):
        """Remove the rows the test created."""
        super().tearDown()
        self._clear_db()

    def test_without_dag_run(self):
        """Check branching when no DagRun exists for the execution date."""
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        expected = {
            'make_choice': State.SUCCESS,
            # The chosen branch, if it has a row, must be left untouched.
            'branch_1': State.NONE,
            'branch_2': State.SKIPPED,
        }

        with create_session() as session:
            tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE)
            # Iterate inside the session so the query is evaluated while it is open.
            self._assert_ti_states(tis, expected)

    def test_branch_list_without_dag_run(self):
        """Check that the branch operator supports branching off to a list of tasks."""
        self.branch_op = ChooseBranchOneTwo(task_id='make_choice', dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
        self.branch_3.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        expected = {
            "make_choice": State.SUCCESS,
            "branch_1": State.NONE,
            "branch_2": State.NONE,
            "branch_3": State.SKIPPED,
        }

        with create_session() as session:
            tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE)
            # Iterate inside the session so the query is evaluated while it is open.
            self._assert_ti_states(tis, expected)

    def test_with_dag_run(self):
        """Check branching when a running DagRun exists for the execution date."""
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        dagrun = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_ti_states(
            dagrun.get_task_instances(),
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.NONE,
                'branch_2': State.SKIPPED,
            },
        )

    def test_with_skip_in_branch_downstream_dependencies(self):
        """Check that a task downstream of the chosen branch is not skipped.

        ``branch_2`` is a direct downstream of ``make_choice`` and also a
        downstream of the chosen ``branch_1``; it is expected to keep state
        ``None`` rather than be marked skipped.
        """
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_op >> self.branch_1 >> self.branch_2
        self.branch_op >> self.branch_2
        self.dag.clear()

        dagrun = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_ti_states(
            dagrun.get_task_instances(),
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.NONE,
                'branch_2': State.NONE,
            },
        )