[v3-0-test] Deserialize response of `get_all` when we call `XCom.get_all` (#53020) (#53102)
(cherry picked from commit bdc9cd115d103614159ea7492db8ff16607c6959)
Co-authored-by: Amogh Desai <amoghrajesh1999@gmail.com>
diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py
index 770dbf5..82df8d1 100644
--- a/task-sdk/src/airflow/sdk/bases/xcom.py
+++ b/task-sdk/src/airflow/sdk/bases/xcom.py
@@ -290,6 +290,7 @@
:return: List of all XCom values if found.
"""
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+ from airflow.serialization.serde import deserialize
msg = SUPERVISOR_COMMS.send(
msg=GetXComSequenceSlice(
@@ -306,7 +307,10 @@
if not isinstance(msg, XComSequenceSliceResult):
raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}")
- return msg.root
+ result = deserialize(msg.root)
+ if not result:
+ return None
+ return result
@staticmethod
def serialize_value(
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 6c6e597..bdd3b24 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -351,17 +351,19 @@
# If map_indexes is not specified, pull xcoms from all map indexes for each task
if isinstance(map_indexes, ArgNotSet):
- xcoms = [
- value
- for t_id in task_ids
- for value in XCom.get_all(
+ xcoms: list[Any] = []
+ for t_id in task_ids:
+ values = XCom.get_all(
run_id=run_id,
key=key,
task_id=t_id,
dag_id=dag_id,
)
- ]
+ if values is None:
+ xcoms.append(None)
+ else:
+ xcoms.extend(values)
# For single task pulling from unmapped task, return single value
if single_task_requested and len(xcoms) == 1:
return xcoms[0]
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index f2ff896..362b89e 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -1355,6 +1355,7 @@
pytest.param("hello", id="string_value"),
pytest.param("'hello'", id="quoted_string_value"),
pytest.param({"key": "value"}, id="json_value"),
+ pytest.param([], id="empty_list_no_xcoms_found"),
pytest.param((1, 2, 3), id="tuple_int_value"),
pytest.param([1, 2, 3], id="list_int_value"),
pytest.param(42, id="int_value"),
@@ -1377,6 +1378,9 @@
"""
map_indexes_kwarg = {} if map_indexes is NOTSET else {"map_indexes": map_indexes}
task_ids_kwarg = {} if task_ids is NOTSET else {"task_ids": task_ids}
+ from airflow.serialization.serde import deserialize
+
+ spy_agency.spy_on(deserialize)
class CustomOperator(BaseOperator):
def execute(self, context):
@@ -1402,6 +1406,7 @@
mock_supervisor_comms.send.side_effect = mock_send_side_effect
run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+ spy_agency.assert_spy_called_with(deserialize, ser_value)
if not isinstance(task_ids, Iterable) or isinstance(task_ids, str):
task_ids = [task_ids]
@@ -1507,6 +1512,54 @@
assert mock_get_one.called
assert not mock_get_all.called
+ @pytest.mark.parametrize(
+ "api_return_value",
+ [
+ pytest.param(("data", "test_value"), id="api returns tuple"),
+ pytest.param({"data": "test_value"}, id="api returns dict"),
+ pytest.param(None, id="api returns None, no xcom found"),
+ ],
+ )
+ def test_xcom_pull_with_no_map_index(
+ self,
+ api_return_value,
+ create_runtime_ti,
+ mock_supervisor_comms,
+ ):
+ """
+ Test xcom_pull when map_indexes is not specified, so that XCom.get_all is called.
+ It also verifies that the response is deserialized before it is returned.
+ """
+ test_task_id = "pull_task"
+ task = BaseOperator(task_id=test_task_id)
+ runtime_ti = create_runtime_ti(task=task)
+
+ ser_value = BaseXCom.serialize_value(api_return_value)
+
+ def mock_send_side_effect(*args, **kwargs):
+ msg = kwargs.get("msg") or args[0]
+ if isinstance(msg, GetXComSequenceSlice):
+ return XComSequenceSliceResult(root=[ser_value])
+ return XComResult(key="test_key", value=None)
+
+ mock_supervisor_comms.send.side_effect = mock_send_side_effect
+ result = runtime_ti.xcom_pull(key="test_key", task_ids="task_a")
+
+ # For the tuple and dict cases, the assertion below verifies that XCom.get_all deserialized the value correctly
+ assert result == api_return_value
+
+ mock_supervisor_comms.send.assert_called_once_with(
+ msg=GetXComSequenceSlice(
+ key="test_key",
+ dag_id=runtime_ti.dag_id,
+ run_id=runtime_ti.run_id,
+ task_id="task_a",
+ start=None,
+ stop=None,
+ step=None,
+ ),
+ )
+
def test_get_param_from_context(
self, mocked_parse, make_ti_context, mock_supervisor_comms, create_runtime_ti
):