blob: a211a9324d287352e2725a880c24e49bfa454762 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import datetime
import operator
import os
from unittest import mock
from unittest.mock import MagicMock
import pytest
from sqlalchemy.orm import Session
from airflow.configuration import conf
from airflow.models.dagrun import DagRun, DagRunType
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend
from airflow.operators.empty import EmptyOperator
from airflow.settings import json
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.xcom import XCOM_RETURN_KEY
from tests.test_utils.config import conf_vars
class CustomXCom(BaseXCom):
orm_deserialize_value = mock.Mock()
@pytest.fixture(autouse=True)
def reset_db():
"""Reset XCom entries."""
with create_session() as session:
session.query(DagRun).delete()
session.query(XCom).delete()
@pytest.fixture()
def task_instance_factory(request, session: Session):
def func(*, dag_id, task_id, execution_date):
run_id = DagRun.generate_run_id(DagRunType.SCHEDULED, execution_date)
run = DagRun(
dag_id=dag_id,
run_type=DagRunType.SCHEDULED,
run_id=run_id,
execution_date=execution_date,
)
session.add(run)
ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
ti.dag_id = dag_id
session.add(ti)
session.commit()
def cleanup_database():
# This should also clear task instances by cascading.
session.query(DagRun).filter_by(id=run.id).delete()
session.commit()
request.addfinalizer(cleanup_database)
return ti
return func
@pytest.fixture()
def task_instance(task_instance_factory):
return task_instance_factory(
dag_id="dag",
task_id="task_1",
execution_date=timezone.datetime(2021, 12, 3, 4, 56),
)
@pytest.fixture()
def task_instances(session, task_instance):
ti2 = TaskInstance(EmptyOperator(task_id="task_2"), run_id=task_instance.run_id)
ti2.dag_id = task_instance.dag_id
session.add(ti2)
session.commit()
return task_instance, ti2 # ti2 will be cleaned up automatically with the DAG run.
class TestXCom:
@conf_vars({("core", "xcom_backend"): "tests.models.test_xcom.CustomXCom"})
def test_resolve_xcom_class(self):
cls = resolve_xcom_backend()
assert issubclass(cls, CustomXCom)
@conf_vars({("core", "xcom_backend"): "", ("core", "enable_xcom_pickling"): "False"})
def test_resolve_xcom_class_fallback_to_basexcom(self):
cls = resolve_xcom_backend()
assert issubclass(cls, BaseXCom)
assert cls.serialize_value([1]) == b"[1]"
@conf_vars({("core", "enable_xcom_pickling"): "False"})
@conf_vars({("core", "xcom_backend"): "to be removed"})
def test_resolve_xcom_class_fallback_to_basexcom_no_config(self):
conf.remove_option("core", "xcom_backend")
cls = resolve_xcom_backend()
assert issubclass(cls, BaseXCom)
assert cls.serialize_value([1]) == b"[1]"
def test_xcom_deserialize_with_json_to_pickle_switch(self, task_instance, session):
ti_key = TaskInstanceKey(
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
)
with conf_vars({("core", "enable_xcom_pickling"): "False"}):
XCom.set(
key="xcom_test3",
value={"key": "value"},
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)
with conf_vars({("core", "enable_xcom_pickling"): "True"}):
ret_value = XCom.get_value(key="xcom_test3", ti_key=ti_key, session=session)
assert ret_value == {"key": "value"}
def test_xcom_deserialize_with_pickle_to_json_switch(self, task_instance, session):
with conf_vars({("core", "enable_xcom_pickling"): "True"}):
XCom.set(
key="xcom_test3",
value={"key": "value"},
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)
with conf_vars({("core", "enable_xcom_pickling"): "False"}):
ret_value = XCom.get_one(
key="xcom_test3",
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)
assert ret_value == {"key": "value"}
@conf_vars({("core", "xcom_enable_pickling"): "False"})
def test_xcom_disable_pickle_type_fail_on_non_json(self, task_instance, session):
class PickleRce:
def __reduce__(self):
return os.system, ("ls -alt",)
with pytest.raises(TypeError):
XCom.set(
key="xcom_test3",
value=PickleRce(),
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)
@mock.patch("airflow.models.xcom.XCom.orm_deserialize_value")
def test_xcom_init_on_load_uses_orm_deserialize_value(self, mock_orm_deserialize):
instance = BaseXCom(
key="key",
value="value",
timestamp=timezone.utcnow(),
execution_date=timezone.utcnow(),
task_id="task_id",
dag_id="dag_id",
)
instance.init_on_load()
mock_orm_deserialize.assert_called_once_with()
@conf_vars({("core", "xcom_backend"): "tests.models.test_xcom.CustomXCom"})
def test_get_one_custom_backend_no_use_orm_deserialize_value(self, task_instance, session):
"""Test that XCom.get_one does not call orm_deserialize_value"""
XCom = resolve_xcom_backend()
XCom.set(
key=XCOM_RETURN_KEY,
value={"key": "value"},
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)
value = XCom.get_one(
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)
assert value == {"key": "value"}
XCom.orm_deserialize_value.assert_not_called()
@conf_vars({("core", "enable_xcom_pickling"): "False"})
@mock.patch("airflow.models.xcom.conf.getimport")
def test_set_serialize_call_old_signature(self, get_import, task_instance):
"""
When XCom.serialize_value takes only param ``value``, other kwargs should be ignored.
"""
serialize_watcher = MagicMock()
class OldSignatureXCom(BaseXCom):
@staticmethod
def serialize_value(value, **kwargs):
serialize_watcher(value=value, **kwargs)
return json.dumps(value).encode("utf-8")
get_import.return_value = OldSignatureXCom
XCom = resolve_xcom_backend()
XCom.set(
key=XCOM_RETURN_KEY,
value={"my_xcom_key": "my_xcom_value"},
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
)
serialize_watcher.assert_called_once_with(value={"my_xcom_key": "my_xcom_value"})
@conf_vars({("core", "enable_xcom_pickling"): "False"})
@mock.patch("airflow.models.xcom.conf.getimport")
def test_set_serialize_call_current_signature(self, get_import, task_instance):
"""
When XCom.serialize_value includes params execution_date, key, dag_id, task_id and run_id,
then XCom.set should pass all of them.
"""
serialize_watcher = MagicMock()
class CurrentSignatureXCom(BaseXCom):
@staticmethod
def serialize_value(
value,
key=None,
dag_id=None,
task_id=None,
run_id=None,
map_index=None,
):
serialize_watcher(
value=value,
key=key,
dag_id=dag_id,
task_id=task_id,
run_id=run_id,
map_index=map_index,
)
return json.dumps(value).encode("utf-8")
get_import.return_value = CurrentSignatureXCom
XCom = resolve_xcom_backend()
XCom.set(
key=XCOM_RETURN_KEY,
value={"my_xcom_key": "my_xcom_value"},
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
map_index=-1,
)
serialize_watcher.assert_called_once_with(
key=XCOM_RETURN_KEY,
value={"my_xcom_key": "my_xcom_value"},
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
map_index=-1,
)
@pytest.fixture(
params=[
pytest.param("true", id="enable_xcom_pickling=true"),
pytest.param("false", id="enable_xcom_pickling=false"),
],
)
def setup_xcom_pickling(request):
with conf_vars({("core", "enable_xcom_pickling"): str(request.param)}):
yield
@pytest.fixture()
def push_simple_json_xcom(session):
def func(*, ti: TaskInstance, key: str, value):
return XCom.set(
key=key,
value=value,
dag_id=ti.dag_id,
task_id=ti.task_id,
run_id=ti.run_id,
session=session,
)
return func
@pytest.mark.usefixtures("setup_xcom_pickling")
class TestXComGet:
@pytest.fixture()
def setup_for_xcom_get_one(self, task_instance, push_simple_json_xcom):
push_simple_json_xcom(ti=task_instance, key="xcom_1", value={"key": "value"})
@pytest.mark.usefixtures("setup_for_xcom_get_one")
def test_xcom_get_one(self, session, task_instance):
stored_value = XCom.get_one(
key="xcom_1",
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)
assert stored_value == {"key": "value"}
@pytest.mark.usefixtures("setup_for_xcom_get_one")
def test_xcom_get_one_with_execution_date(self, session, task_instance):
with pytest.deprecated_call():
stored_value = XCom.get_one(
key="xcom_1",
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
execution_date=task_instance.execution_date,
session=session,
)
assert stored_value == {"key": "value"}
@pytest.fixture()
def tis_for_xcom_get_one_from_prior_date(self, task_instance_factory, push_simple_json_xcom):
date1 = timezone.datetime(2021, 12, 3, 4, 56)
ti1 = task_instance_factory(dag_id="dag", execution_date=date1, task_id="task_1")
ti2 = task_instance_factory(
dag_id="dag",
execution_date=date1 + datetime.timedelta(days=1),
task_id="task_1",
)
# The earlier run pushes an XCom, but not the later run, but the later
# run can get this earlier XCom with ``include_prior_dates``.
push_simple_json_xcom(ti=ti1, key="xcom_1", value={"key": "value"})
return ti1, ti2
def test_xcom_get_one_from_prior_date(self, session, tis_for_xcom_get_one_from_prior_date):
_, ti2 = tis_for_xcom_get_one_from_prior_date
retrieved_value = XCom.get_one(
run_id=ti2.run_id,
key="xcom_1",
task_id="task_1",
dag_id="dag",
include_prior_dates=True,
session=session,
)
assert retrieved_value == {"key": "value"}
def test_xcom_get_one_from_prior_with_execution_date(
self,
session,
tis_for_xcom_get_one_from_prior_date,
):
_, ti2 = tis_for_xcom_get_one_from_prior_date
with pytest.deprecated_call():
retrieved_value = XCom.get_one(
execution_date=ti2.execution_date,
key="xcom_1",
task_id="task_1",
dag_id="dag",
include_prior_dates=True,
session=session,
)
assert retrieved_value == {"key": "value"}
@pytest.fixture()
def setup_for_xcom_get_many_single_argument_value(self, task_instance, push_simple_json_xcom):
push_simple_json_xcom(ti=task_instance, key="xcom_1", value={"key": "value"})
@pytest.mark.usefixtures("setup_for_xcom_get_many_single_argument_value")
def test_xcom_get_many_single_argument_value(self, session, task_instance):
stored_xcoms = XCom.get_many(
key="xcom_1",
dag_ids=task_instance.dag_id,
task_ids=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
).all()
assert len(stored_xcoms) == 1
assert stored_xcoms[0].key == "xcom_1"
assert stored_xcoms[0].value == {"key": "value"}
@pytest.mark.usefixtures("setup_for_xcom_get_many_single_argument_value")
def test_xcom_get_many_single_argument_value_with_execution_date(self, session, task_instance):
with pytest.deprecated_call():
stored_xcoms = XCom.get_many(
execution_date=task_instance.execution_date,
key="xcom_1",
dag_ids=task_instance.dag_id,
task_ids=task_instance.task_id,
session=session,
).all()
assert len(stored_xcoms) == 1
assert stored_xcoms[0].key == "xcom_1"
assert stored_xcoms[0].value == {"key": "value"}
@pytest.fixture()
def setup_for_xcom_get_many_multiple_tasks(self, task_instances, push_simple_json_xcom):
ti1, ti2 = task_instances
push_simple_json_xcom(ti=ti1, key="xcom_1", value={"key1": "value1"})
push_simple_json_xcom(ti=ti2, key="xcom_1", value={"key2": "value2"})
@pytest.mark.usefixtures("setup_for_xcom_get_many_multiple_tasks")
def test_xcom_get_many_multiple_tasks(self, session, task_instance):
stored_xcoms = XCom.get_many(
key="xcom_1",
dag_ids=task_instance.dag_id,
task_ids=["task_1", "task_2"],
run_id=task_instance.run_id,
session=session,
)
sorted_values = [x.value for x in sorted(stored_xcoms, key=operator.attrgetter("task_id"))]
assert sorted_values == [{"key1": "value1"}, {"key2": "value2"}]
@pytest.mark.usefixtures("setup_for_xcom_get_many_multiple_tasks")
def test_xcom_get_many_multiple_tasks_with_execution_date(self, session, task_instance):
with pytest.deprecated_call():
stored_xcoms = XCom.get_many(
execution_date=task_instance.execution_date,
key="xcom_1",
dag_ids=task_instance.dag_id,
task_ids=["task_1", "task_2"],
session=session,
)
sorted_values = [x.value for x in sorted(stored_xcoms, key=operator.attrgetter("task_id"))]
assert sorted_values == [{"key1": "value1"}, {"key2": "value2"}]
@pytest.fixture()
def tis_for_xcom_get_many_from_prior_dates(self, task_instance_factory, push_simple_json_xcom):
date1 = timezone.datetime(2021, 12, 3, 4, 56)
date2 = date1 + datetime.timedelta(days=1)
ti1 = task_instance_factory(dag_id="dag", task_id="task_1", execution_date=date1)
ti2 = task_instance_factory(dag_id="dag", task_id="task_1", execution_date=date2)
push_simple_json_xcom(ti=ti1, key="xcom_1", value={"key1": "value1"})
push_simple_json_xcom(ti=ti2, key="xcom_1", value={"key2": "value2"})
return ti1, ti2
def test_xcom_get_many_from_prior_dates(self, session, tis_for_xcom_get_many_from_prior_dates):
ti1, ti2 = tis_for_xcom_get_many_from_prior_dates
stored_xcoms = XCom.get_many(
run_id=ti2.run_id,
key="xcom_1",
dag_ids="dag",
task_ids="task_1",
include_prior_dates=True,
session=session,
)
# The retrieved XComs should be ordered by logical date, latest first.
assert [x.value for x in stored_xcoms] == [{"key2": "value2"}, {"key1": "value1"}]
assert [x.execution_date for x in stored_xcoms] == [ti2.execution_date, ti1.execution_date]
def test_xcom_get_many_from_prior_dates_with_execution_date(
self,
session,
tis_for_xcom_get_many_from_prior_dates,
):
ti1, ti2 = tis_for_xcom_get_many_from_prior_dates
with pytest.deprecated_call():
stored_xcoms = XCom.get_many(
execution_date=ti2.execution_date,
key="xcom_1",
dag_ids="dag",
task_ids="task_1",
include_prior_dates=True,
session=session,
)
# The retrieved XComs should be ordered by logical date, latest first.
assert [x.value for x in stored_xcoms] == [{"key2": "value2"}, {"key1": "value1"}]
assert [x.execution_date for x in stored_xcoms] == [ti2.execution_date, ti1.execution_date]
@pytest.mark.usefixtures("setup_xcom_pickling")
class TestXComSet:
def test_xcom_set(self, session, task_instance):
XCom.set(
key="xcom_1",
value={"key": "value"},
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)
stored_xcoms = session.query(XCom).all()
assert stored_xcoms[0].key == "xcom_1"
assert stored_xcoms[0].value == {"key": "value"}
assert stored_xcoms[0].dag_id == "dag"
assert stored_xcoms[0].task_id == "task_1"
assert stored_xcoms[0].execution_date == task_instance.execution_date
def test_xcom_set_with_execution_date(self, session, task_instance):
with pytest.deprecated_call():
XCom.set(
key="xcom_1",
value={"key": "value"},
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
execution_date=task_instance.execution_date,
session=session,
)
stored_xcoms = session.query(XCom).all()
assert stored_xcoms[0].key == "xcom_1"
assert stored_xcoms[0].value == {"key": "value"}
assert stored_xcoms[0].dag_id == "dag"
assert stored_xcoms[0].task_id == "task_1"
assert stored_xcoms[0].execution_date == task_instance.execution_date
@pytest.fixture()
def setup_for_xcom_set_again_replace(self, task_instance, push_simple_json_xcom):
push_simple_json_xcom(ti=task_instance, key="xcom_1", value={"key1": "value1"})
@pytest.mark.usefixtures("setup_for_xcom_set_again_replace")
def test_xcom_set_again_replace(self, session, task_instance):
assert session.query(XCom).one().value == {"key1": "value1"}
XCom.set(
key="xcom_1",
value={"key2": "value2"},
dag_id=task_instance.dag_id,
task_id="task_1",
run_id=task_instance.run_id,
session=session,
)
assert session.query(XCom).one().value == {"key2": "value2"}
@pytest.mark.usefixtures("setup_for_xcom_set_again_replace")
def test_xcom_set_again_replace_with_execution_date(self, session, task_instance):
assert session.query(XCom).one().value == {"key1": "value1"}
with pytest.deprecated_call():
XCom.set(
key="xcom_1",
value={"key2": "value2"},
dag_id=task_instance.dag_id,
task_id="task_1",
execution_date=task_instance.execution_date,
session=session,
)
assert session.query(XCom).one().value == {"key2": "value2"}
@pytest.mark.usefixtures("setup_xcom_pickling")
class TestXComClear:
@pytest.fixture()
def setup_for_xcom_clear(self, task_instance, push_simple_json_xcom):
push_simple_json_xcom(ti=task_instance, key="xcom_1", value={"key": "value"})
@pytest.mark.usefixtures("setup_for_xcom_clear")
def test_xcom_clear(self, session, task_instance):
assert session.query(XCom).count() == 1
XCom.clear(
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)
assert session.query(XCom).count() == 0
@pytest.mark.usefixtures("setup_for_xcom_clear")
def test_xcom_clear_with_execution_date(self, session, task_instance):
assert session.query(XCom).count() == 1
with pytest.deprecated_call():
XCom.clear(
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
execution_date=task_instance.execution_date,
session=session,
)
assert session.query(XCom).count() == 0
@pytest.mark.usefixtures("setup_for_xcom_clear")
def test_xcom_clear_different_run(self, session, task_instance):
XCom.clear(
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id="different_run",
session=session,
)
assert session.query(XCom).count() == 1
@pytest.mark.usefixtures("setup_for_xcom_clear")
def test_xcom_clear_different_execution_date(self, session, task_instance):
with pytest.deprecated_call():
XCom.clear(
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
execution_date=timezone.utcnow(),
session=session,
)
assert session.query(XCom).count() == 1