| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| import pytest |
| |
| from airflow.models.xcom_arg import XComArg |
| from airflow.providers.standard.operators.bash import BashOperator |
| from airflow.providers.standard.operators.python import PythonOperator |
| from airflow.utils.types import NOTSET |
| |
| from tests_common.test_utils.db import clear_db_dags, clear_db_runs |
| |
# Module-level pytest marker: applies ``db_test`` to every test collected from this file.
pytestmark = pytest.mark.db_test
| |
| |
VALUE = 42


def assert_is_value(num: int):
    """Raise ``Exception`` unless ``num`` equals the module constant ``VALUE``.

    Used as a ``python_callable`` so that a task fails when it is handed anything
    other than ``VALUE``.
    """
    if num == VALUE:
        return
    raise Exception("The test has failed")
| |
| |
def build_python_op(dag_maker):
    """Create a one-task DAG ``test_xcom_dag`` plus a DAG run, and return its ``PythonOperator``.

    :param dag_maker: the ``dag_maker`` fixture used to build the DAG and its run.
    :return: the ``test_xcom_op`` operator declared inside the DAG.
    """

    def _echo_task_id(task_id):
        return f"OP:{task_id}"

    with dag_maker(dag_id="test_xcom_dag"):
        python_op = PythonOperator(
            task_id="test_xcom_op",
            python_callable=_echo_task_id,
            do_xcom_push=True,
        )

    # The run is created only after the DAG context has been closed.
    dag_maker.create_dagrun()
    return python_op
| |
| |
@pytest.fixture(autouse=True)
def clear_db():
    """Call ``clear_db_runs`` and then ``clear_db_dags`` ahead of every test (autouse fixture)."""
    for reset in (clear_db_runs, clear_db_dags):
        reset()
| |
| |
class TestXComArgBuild:
    """Checks on ``XComArg`` objects at DAG-definition time: rendering, keys and dependency wiring."""

    def test_xcom_ctor(self, dag_maker):
        """A keyed XComArg exposes operator and key, compares equal to a twin, and renders a template."""
        op = build_python_op(dag_maker)
        arg = XComArg(op, "test_key")
        rendered = (
            "{{ task_instance.xcom_pull(task_ids='test_xcom_op', dag_id='test_xcom_dag', key='test_key') }}"
        )

        assert arg
        assert arg.operator == op
        assert arg.key == "test_key"
        # Exercises the overridden __eq__: a second instance with the same operator/key is equal.
        assert arg == XComArg(op, "test_key")
        assert str(arg) == rendered
        # Interpolating into an f-string must yield the very same template text.
        assert f"echo {arg}" == "echo " + rendered

    def test_xcom_key_is_empty_str(self, dag_maker):
        """An empty-string key is preserved and rendered as ``key=''``."""
        op = build_python_op(dag_maker)
        arg = XComArg(op, key="")
        expected = "{{ task_instance.xcom_pull(task_ids='test_xcom_op', dag_id='test_xcom_dag', key='') }}"

        assert arg.key == ""
        assert str(arg) == expected

    def test_set_downstream(self, dag_maker):
        """Chaining with ``>>`` through XComArgs makes each underlying operator downstream of the last."""
        with dag_maker("test_set_downstream"):
            op_a = BashOperator(task_id="a", bash_command="echo a")
            op_b = BashOperator(task_id="b", bash_command="echo b")
            first = BashOperator(task_id="c", bash_command="echo c")
            last = BashOperator(task_id="d", bash_command="echo c")
            arg_a = XComArg(op_a)
            arg_b = XComArg(op_b)

            first >> arg_a >> arg_b >> last
        dag_maker.create_dagrun()

        assert op_a in first.downstream_list
        assert op_b in op_a.downstream_list
        assert last in op_b.downstream_list

    def test_set_upstream(self, dag_maker):
        """Chaining with ``<<`` through XComArgs makes each underlying operator upstream of the last."""
        with dag_maker("test_set_upstream"):
            op_a = BashOperator(task_id="a", bash_command="echo a")
            op_b = BashOperator(task_id="b", bash_command="echo b")
            first = BashOperator(task_id="c", bash_command="echo c")
            last = BashOperator(task_id="d", bash_command="echo c")
            arg_a = XComArg(op_a)
            arg_b = XComArg(op_b)

            first << arg_a << arg_b << last
        dag_maker.create_dagrun()

        assert op_a in first.upstream_list
        assert op_b in op_a.upstream_list
        assert last in op_b.upstream_list

    def test_xcom_arg_property_of_base_operator(self, dag_maker):
        """``operator.output`` equals an ``XComArg`` built directly from that operator."""
        with dag_maker("test_xcom_arg_property_of_base_operator"):
            bash_op = BashOperator(task_id="a", bash_command="echo a")
        dag_maker.create_dagrun()

        assert bash_op.output == XComArg(bash_op)

    def test_xcom_key_getitem_not_str(self, dag_maker):
        """Indexing with a non-string raises ``ValueError`` naming the offending type."""
        arg = XComArg(build_python_op(dag_maker))
        with pytest.raises(ValueError, match="XComArg only supports str lookup, received int"):
            arg[1]

    def test_xcom_key_getitem(self, dag_maker):
        """Indexing with a string returns an object whose key is that string."""
        arg = XComArg(build_python_op(dag_maker), key="another_key")
        assert arg.key == "another_key"

        rekeyed = arg["another_key_2"]
        assert rekeyed.key == "another_key_2"

    def test_xcom_not_iterable(self, dag_maker):
        """Iterating an XComArg raises ``TypeError`` with the standard 'not iterable' message."""
        arg = XComArg(build_python_op(dag_maker))
        with pytest.raises(TypeError) as excinfo:
            list(arg)
        assert str(excinfo.value) == "'XComArg' object is not iterable"
| |
| |
@pytest.mark.system
class TestXComArgRuntime:
    """Runs small DAGs via ``dag.test()`` where a consumer task receives an ``XComArg`` in ``op_args``."""

    def test_xcom_pass_to_op(self, dag_maker):
        """The producer returns ``VALUE``; the consumer runs ``assert_is_value`` on the XComArg."""
        with dag_maker(dag_id="test_xcom_pass_to_op") as dag:
            producer = PythonOperator(
                task_id="return_value_1",
                python_callable=lambda: VALUE,
                do_xcom_push=True,
            )
            consumer = PythonOperator(
                task_id="assert_is_value_1",
                python_callable=assert_is_value,
                op_args=[XComArg(producer)],
            )
            producer >> consumer
        dag.test()

    def test_xcom_push_and_pass(self, dag_maker):
        """The producer pushes ``VALUE`` under ``my_key``; the consumer checks the keyed XComArg."""

        def push_xcom_value(key, value, **context):
            # Push through the task instance found in the context passed to the callable.
            context["task_instance"].xcom_push(key, value)

        with dag_maker(dag_id="test_xcom_push_and_pass") as dag:
            pusher = PythonOperator(
                task_id="push_xcom_value",
                python_callable=push_xcom_value,
                op_args=["my_key", VALUE],
            )
            checker = PythonOperator(
                task_id="assert_is_value_1",
                python_callable=assert_is_value,
                op_args=[XComArg(pusher, key="my_key")],
            )
            pusher >> checker
        dag.test()
| |
| |
@pytest.mark.parametrize(
    "fillvalue, expected_results",
    [
        (NOTSET, {("a", 1), ("b", 2), ("c", 3)}),
        (None, {("a", 1), ("b", 2), ("c", 3), (None, 4)}),
    ],
)
def test_xcom_zip(dag_maker, session, fillvalue, expected_results):
    """Expand ``pull`` over two zipped task outputs and compare what it received.

    ``push_letters`` yields three items and ``push_numbers`` four. With ``NOTSET`` the
    expected pairs stop at the shorter list; with ``fillvalue=None`` a fourth pair
    ``(None, 4)`` is expected as well.
    """
    collected = set()

    with dag_maker(session=session, serialized=True) as dag:

        @dag.task
        def push_letters():
            return ["a", "b", "c"]

        @dag.task
        def push_numbers():
            return [1, 2, 3, 4]

        @dag.task
        def pull(value):
            collected.add(value)

        letters = push_letters()
        numbers = push_numbers()
        pull.expand(value=letters.zip(numbers, fillvalue=fillvalue))

    dr = dag_maker.create_dagrun()

    def run_schedulable(expected_task_ids):
        # Check which task instances are schedulable right now, then run each of them.
        ready = dr.task_instance_scheduling_decisions(session=session).schedulable_tis
        assert sorted(ti.task_id for ti in ready) == expected_task_ids
        for ti in ready:
            dag_maker.run_ti(task_id=ti.task_id, map_index=ti.map_index, dag_run=dr, session=session)

    # Run "push_letters" and "push_numbers".
    run_schedulable(["push_letters", "push_numbers"])
    session.commit()

    # Run "pull": one mapped task instance per expected pair.
    run_schedulable(["pull"] * len(expected_results))

    assert collected == expected_results