Unify lazy db sequence implementations (#39426)
diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py
index 6500d80..332a04e 100644
--- a/airflow/lineage/__init__.py
+++ b/airflow/lineage/__init__.py
@@ -130,10 +130,10 @@
# Remove auto and task_ids
self.inlets = [i for i in self.inlets if not isinstance(i, str)]
- # We manually create a session here since xcom_pull returns a LazyXComAccess iterator.
- # If we do not pass a session a new session will be created, however that session will not be
- # properly closed and will remain open. After we are done iterating we can safely close this
- # session.
+ # We manually create a session here since xcom_pull returns a
+ # LazySelectSequence proxy. If we do not pass a session, a new one
+ # will be created, but that session will not be properly closed.
+ # After we are done iterating, we can safely close this session.
with create_session() as session:
_inlets = self.xcom_pull(
context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
diff --git a/airflow/models/base.py b/airflow/models/base.py
index 1cde7d7..e9f86f8 100644
--- a/airflow/models/base.py
+++ b/airflow/models/base.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations
-from typing import Any
+from typing import TYPE_CHECKING, Any
from sqlalchemy import Column, Integer, MetaData, String, text
from sqlalchemy.orm import registry
@@ -48,7 +48,10 @@
mapper_registry = registry(metadata=metadata)
_sentinel = object()
-Base: Any = mapper_registry.generate_base()
+if TYPE_CHECKING:
+ Base = Any
+else:
+ Base = mapper_registry.generate_base()
ID_LEN = 250
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index ef2a41a..1a9d1e0 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -97,7 +97,7 @@
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
-from airflow.models.xcom import LazyXComAccess, XCom
+from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
@@ -3358,34 +3358,37 @@
return default
if map_indexes is not None or first.map_index < 0:
return XCom.deserialize_value(first)
- query = query.order_by(None).order_by(XCom.map_index.asc())
- return LazyXComAccess.build_from_xcom_query(query)
+ return LazyXComSelectSequence.from_select(
+ query.with_entities(XCom.value).order_by(None).statement,
+ order_by=[XCom.map_index],
+ session=session,
+ )
# At this point either task_ids or map_indexes is explicitly multi-value.
# Order return values to match task_ids and map_indexes ordering.
- query = query.order_by(None)
+ ordering = []
if task_ids is None or isinstance(task_ids, str):
- query = query.order_by(XCom.task_id)
+ ordering.append(XCom.task_id)
+ elif task_id_whens := {tid: i for i, tid in enumerate(task_ids)}:
+ ordering.append(case(task_id_whens, value=XCom.task_id))
else:
- task_id_whens = {tid: i for i, tid in enumerate(task_ids)}
- if task_id_whens:
- query = query.order_by(case(task_id_whens, value=XCom.task_id))
- else:
- query = query.order_by(XCom.task_id)
+ ordering.append(XCom.task_id)
if map_indexes is None or isinstance(map_indexes, int):
- query = query.order_by(XCom.map_index)
+ ordering.append(XCom.map_index)
elif isinstance(map_indexes, range):
order = XCom.map_index
if map_indexes.step < 0:
order = order.desc()
- query = query.order_by(order)
+ ordering.append(order)
+ elif map_index_whens := {map_index: i for i, map_index in enumerate(map_indexes)}:
+ ordering.append(case(map_index_whens, value=XCom.map_index))
else:
- map_index_whens = {map_index: i for i, map_index in enumerate(map_indexes)}
- if map_index_whens:
- query = query.order_by(case(map_index_whens, value=XCom.map_index))
- else:
- query = query.order_by(XCom.map_index)
- return LazyXComAccess.build_from_xcom_query(query)
+ ordering.append(XCom.map_index)
+ return LazyXComSelectSequence.from_select(
+ query.with_entities(XCom.value).order_by(None).statement,
+ order_by=ordering,
+ session=session,
+ )
@provide_session
def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 7a6695c..fe1ebad 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -17,18 +17,14 @@
# under the License.
from __future__ import annotations
-import collections.abc
-import contextlib
import inspect
-import itertools
import json
import logging
import pickle
import warnings
-from functools import cached_property, wraps
-from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Iterable, cast, overload
-import attr
from sqlalchemy import (
Column,
ForeignKeyConstraint,
@@ -38,6 +34,7 @@
PrimaryKeyConstraint,
String,
delete,
+ select,
text,
)
from sqlalchemy.dialects.mysql import LONGBLOB
@@ -45,12 +42,12 @@
from sqlalchemy.orm import Query, reconstructor, relationship
from sqlalchemy.orm.exc import NoResultFound
-from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
from airflow.utils import timezone
+from airflow.utils.db import LazySelectSequence
from airflow.utils.helpers import exactly_one, is_container
from airflow.utils.json import XComDecoder, XComEncoder
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -70,7 +67,9 @@
import datetime
import pendulum
+ from sqlalchemy.engine import Row
from sqlalchemy.orm import Session
+ from sqlalchemy.sql.expression import Select, TextClause
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -222,11 +221,11 @@
if dag_run_id is None:
raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")
- # Seamlessly resolve LazyXComAccess to a list. This is intended to work
+ # Seamlessly resolve LazySelectSequence to a list. This is intended to work
# as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
# it's pushed into XCom, the user should be aware of the performance
# implications, and this avoids leaking the implementation detail.
- if isinstance(value, LazyXComAccess):
+ if isinstance(value, LazySelectSequence):
warning_message = (
"Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
"to list, which may degrade performance. Review resource "
@@ -716,111 +715,19 @@
return BaseXCom._deserialize_value(self, True)
-class _LazyXComAccessIterator(collections.abc.Iterator):
- def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None:
- self._cm = cm
- self._entered = False
-
- def __del__(self) -> None:
- if self._entered:
- self._cm.__exit__(None, None, None)
-
- def __iter__(self) -> collections.abc.Iterator:
- return self
-
- def __next__(self) -> Any:
- return XCom.deserialize_value(next(self._it))
-
- @cached_property
- def _it(self) -> collections.abc.Iterator:
- self._entered = True
- return iter(self._cm.__enter__())
-
-
-@attr.define(slots=True)
-class LazyXComAccess(collections.abc.Sequence):
- """Wrapper to lazily pull XCom with a sequence-like interface.
-
- Note that since the session bound to the parent query may have died when we
- actually access the sequence's content, we must create a new session
- for every function call with ``with_session()``.
+class LazyXComSelectSequence(LazySelectSequence[Any]):
+ """List-like interface to lazily access XCom values.
:meta private:
"""
- _query: Query
- _len: int | None = attr.ib(init=False, default=None)
+ @staticmethod
+ def _rebuild_select(stmt: TextClause) -> Select:
+ return select(XCom.value).from_statement(stmt)
- @classmethod
- def build_from_xcom_query(cls, query: Query) -> LazyXComAccess:
- return cls(query=query.with_entities(XCom.value))
-
- def __repr__(self) -> str:
- return f"LazyXComAccess([{len(self)} items])"
-
- def __str__(self) -> str:
- return str(list(self))
-
- def __eq__(self, other: Any) -> bool:
- if isinstance(other, (list, LazyXComAccess)):
- z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
- return all(x == y for x, y in z)
- return NotImplemented
-
- def __getstate__(self) -> Any:
- # We don't want to go to the trouble of serializing the entire Query
- # object, including its filters, hints, etc. (plus SQLAlchemy does not
- # provide a public API to inspect a query's contents). Converting the
- # query into a SQL string is the best we can get. Theoratically we can
- # do the same for count(), but I think it should be performant enough to
- # calculate only that eagerly.
- with self._get_bound_query() as query:
- statement = query.statement.compile(
- query.session.get_bind(),
- # This inlines all the values into the SQL string to simplify
- # cross-process commuinication as much as possible.
- compile_kwargs={"literal_binds": True},
- )
- return (str(statement), query.count())
-
- def __setstate__(self, state: Any) -> None:
- statement, self._len = state
- self._query = Query(XCom.value).from_statement(text(statement))
-
- def __len__(self):
- if self._len is None:
- with self._get_bound_query() as query:
- self._len = query.count()
- return self._len
-
- def __iter__(self):
- return _LazyXComAccessIterator(self._get_bound_query())
-
- def __getitem__(self, key):
- if not isinstance(key, int):
- raise ValueError("only support index access for now")
- try:
- with self._get_bound_query() as query:
- r = query.offset(key).limit(1).one()
- except NoResultFound:
- raise IndexError(key) from None
- return XCom.deserialize_value(r)
-
- @contextlib.contextmanager
- def _get_bound_query(self) -> Generator[Query, None, None]:
- # Do we have a valid session already?
- if self._query.session and self._query.session.is_active:
- yield self._query
- return
-
- Session = getattr(settings, "Session", None)
- if Session is None:
- raise RuntimeError("Session must be set before!")
- session = Session()
- try:
- yield self._query.with_session(session)
- finally:
- session.close()
+ @staticmethod
+ def _process_row(row: Row) -> Any:
+ return XCom.deserialize_value(row)
def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
diff --git a/airflow/typing_compat.py b/airflow/typing_compat.py
index 5ae2d23..ba96c92 100644
--- a/airflow/typing_compat.py
+++ b/airflow/typing_compat.py
@@ -23,6 +23,7 @@
"Literal",
"ParamSpec",
"Protocol",
+ "Self",
"TypedDict",
"TypeGuard",
"runtime_checkable",
@@ -45,3 +46,8 @@
from typing import ParamSpec, TypeGuard
else:
from typing_extensions import ParamSpec, TypeGuard
+
+if sys.version_info >= (3, 11):
+ from typing import Self
+else:
+ from typing_extensions import Self
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 2b73020..58e688f 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -32,25 +32,26 @@
KeysView,
Mapping,
MutableMapping,
- Sequence,
SupportsIndex,
ValuesView,
- overload,
)
import attrs
import lazy_object_proxy
+from sqlalchemy import select
from airflow.datasets import Dataset, coerce_to_uri
from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.models.dataset import DatasetEvent, DatasetModel
+from airflow.utils.db import LazySelectSequence
from airflow.utils.types import NOTSET
if TYPE_CHECKING:
+ from sqlalchemy.engine import Row
from sqlalchemy.orm import Session
- from sqlalchemy.sql.expression import Select
+ from sqlalchemy.sql.expression import Select, TextClause
from airflow.models.baseoperator import BaseOperator
- from airflow.models.dataset import DatasetEvent
# NOTE: Please keep this in sync with the following:
# * Context in airflow/utils/context.pyi.
@@ -187,57 +188,23 @@
return self._dict[uri]
-@attrs.define()
-class InletEventsAccessor(Sequence["DatasetEvent"]):
- """Lazy sequence to access inlet dataset events.
+class LazyDatasetEventSelectSequence(LazySelectSequence[DatasetEvent]):
+ """List-like interface to lazily access DatasetEvent rows.
:meta private:
"""
- _uri: str
- _session: Session
+ @staticmethod
+ def _rebuild_select(stmt: TextClause) -> Select:
+ return select(DatasetEvent).from_statement(stmt)
- def _get_select_stmt(self, *, reverse: bool = False) -> Select:
- from sqlalchemy import select
-
- from airflow.models.dataset import DatasetEvent, DatasetModel
-
- stmt = select(DatasetEvent).join(DatasetEvent.dataset).where(DatasetModel.uri == self._uri)
- if reverse:
- return stmt.order_by(DatasetEvent.timestamp.desc())
- return stmt.order_by(DatasetEvent.timestamp.asc())
-
- def __reversed__(self) -> Iterator[DatasetEvent]:
- return iter(self._session.scalar(self._get_select_stmt(reverse=True)))
-
- def __iter__(self) -> Iterator[DatasetEvent]:
- return iter(self._session.scalar(self._get_select_stmt()))
-
- @overload
- def __getitem__(self, key: int) -> DatasetEvent: ...
-
- @overload
- def __getitem__(self, key: slice) -> Sequence[DatasetEvent]: ...
-
- def __getitem__(self, key: int | slice) -> DatasetEvent | Sequence[DatasetEvent]:
- if not isinstance(key, int):
- raise ValueError("non-index access is not supported")
- if key >= 0:
- stmt = self._get_select_stmt().offset(key)
- else:
- stmt = self._get_select_stmt(reverse=True).offset(-1 - key)
- if (event := self._session.scalar(stmt.limit(1))) is not None:
- return event
- raise IndexError(key)
-
- def __len__(self) -> int:
- from sqlalchemy import func, select
-
- return self._session.scalar(select(func.count()).select_from(self._get_select_stmt()))
+ @staticmethod
+ def _process_row(row: Row) -> DatasetEvent:
+ return row[0]
@attrs.define(init=False)
-class InletEventsAccessors(Mapping[str, InletEventsAccessor]):
+class InletEventsAccessors(Mapping[str, LazyDatasetEventSelectSequence]):
"""Lazy mapping for inlet dataset events accessors.
:meta private:
@@ -258,14 +225,18 @@
def __len__(self) -> int:
return len(self._inlets)
- def __getitem__(self, key: int | str | Dataset) -> InletEventsAccessor:
+ def __getitem__(self, key: int | str | Dataset) -> LazyDatasetEventSelectSequence:
if isinstance(key, int): # Support index access; it's easier for trivial cases.
dataset = self._inlets[key]
if not isinstance(dataset, Dataset):
raise IndexError(key)
else:
dataset = self._datasets[coerce_to_uri(key)]
- return InletEventsAccessor(dataset.uri, session=self._session)
+ return LazyDatasetEventSelectSequence.from_select(
+ select(DatasetEvent).join(DatasetEvent.dataset).where(DatasetModel.uri == dataset.uri),
+ order_by=[DatasetEvent.timestamp],
+ session=self._session,
+ )
class AirflowContextDeprecationWarning(RemovedInAirflow3Warning):
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index e7afb04..fb64169 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -17,8 +17,10 @@
# under the License.
from __future__ import annotations
+import collections.abc
import contextlib
import enum
+import itertools
import json
import logging
import os
@@ -27,8 +29,20 @@
import warnings
from dataclasses import dataclass
from tempfile import gettempdir
-from typing import TYPE_CHECKING, Callable, Generator, Iterable
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Generator,
+ Iterable,
+ Iterator,
+ Protocol,
+ Sequence,
+ TypeVar,
+ overload,
+)
+import attrs
from sqlalchemy import (
Table,
and_,
@@ -54,16 +68,28 @@
# TODO: remove create_session once we decide to break backward compatibility
from airflow.utils.session import NEW_SESSION, create_session, provide_session # noqa: F401
+from airflow.utils.task_instance_session import get_current_task_instance_session
if TYPE_CHECKING:
from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
+ from sqlalchemy.engine import Row
from sqlalchemy.orm import Query, Session
- from sqlalchemy.sql.elements import ClauseElement
+ from sqlalchemy.sql.elements import ClauseElement, TextClause
from sqlalchemy.sql.selectable import Select
- from airflow.models.base import Base
from airflow.models.connection import Connection
+ from airflow.typing_compat import Self
+
+ # TODO: Import this from sqlalchemy.orm instead when switching to SQLA 2.
+ # https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.MappedClassProtocol
+ class MappedClassProtocol(Protocol):
+ """Protocol for SQLAlchemy model base."""
+
+ __tablename__: str
+
+
+T = TypeVar("T")
log = logging.getLogger(__name__)
@@ -1028,7 +1054,7 @@
)
-def reflect_tables(tables: list[Base | str] | None, session):
+def reflect_tables(tables: list[MappedClassProtocol | str] | None, session):
"""
When running checks prior to upgrades, we use reflection to determine current state of the database.
@@ -1416,7 +1442,7 @@
ref_table="task_instance",
)
- models_list: list[tuple[Base, str, BadReferenceConfig]] = [
+ models_list: list[tuple[MappedClassProtocol, str, BadReferenceConfig]] = [
(TaskInstance, "2.2", missing_dag_run_config),
(TaskReschedule, "2.2", missing_ti_config),
(RenderedTaskInstanceFields, "2.3", missing_ti_config),
@@ -1875,7 +1901,7 @@
def get_query_count(query_stmt: Select, *, session: Session) -> int:
- """Get count of query.
+ """Get count of a query.
A SELECT COUNT() FROM is issued against the subquery built from the
given statement. The ORDER BY clause is stripped from the statement
@@ -1888,8 +1914,21 @@
return session.scalar(count_stmt)
+def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
+    """Check whether there is at least one row matching a query.
+
+    A SELECT 1 FROM ... LIMIT 1 is issued against the subquery built from the
+    given statement, with its ORDER BY stripped since ordering is unnecessary
+    here. ``scalar()`` gives None on no rows, so compare to return a real bool.
+
+    :meta private:
+    """
+    exists_stmt = select(literal(True)).select_from(query_stmt.order_by(None).subquery()).limit(1)
+    return session.scalar(exists_stmt) is not None
+
+
def exists_query(*where: ClauseElement, session: Session) -> bool:
- """Check whether there is at least one row matching given clause.
+ """Check whether there is at least one row matching given clauses.
This does a SELECT 1 WHERE ... LIMIT 1 and check the result.
@@ -1897,3 +1936,122 @@
"""
stmt = select(literal(True)).where(*where).limit(1)
return session.scalar(stmt) is not None
+
+
+@attrs.define(slots=True)
+class LazySelectSequence(Sequence[T]):
+ """List-like interface to lazily access a database model query.
+
+ The intended use case is inside a task execution context, where we manage an
+ active SQLAlchemy session in the background.
+
+ This is an abstract base class. Each use case should subclass, and implement
+ the following static methods:
+
+ * ``_rebuild_select`` is called when a lazy sequence is unpickled. Since it
+ is not easy to pickle SQLAlchemy constructs, this class serializes the
+ SELECT statements into plain text to storage. This method is called on
+ deserialization to convert the textual clause back into an ORM SELECT.
+ * ``_process_row`` is called when an item is accessed. The lazy sequence
+ uses ``session.execute()`` to fetch rows from the database, and this
+ method should know how to process each row into a value.
+
+ :meta private:
+ """
+
+ _select_asc: ClauseElement
+ _select_desc: ClauseElement
+ _session: Session = attrs.field(kw_only=True, factory=get_current_task_instance_session)
+ _len: int | None = attrs.field(init=False, default=None)
+
+ @classmethod
+ def from_select(
+ cls,
+ select: Select,
+ *,
+ order_by: Sequence[ClauseElement],
+ session: Session | None = None,
+ ) -> Self:
+ s1 = select
+ for col in order_by:
+ s1 = s1.order_by(col.asc())
+ s2 = select
+ for col in order_by:
+ s2 = s2.order_by(col.desc())
+ return cls(s1, s2, session=session or get_current_task_instance_session())
+
+ @staticmethod
+ def _rebuild_select(stmt: TextClause) -> Select:
+ """Rebuild a textual statement into an ORM-configured SELECT statement.
+
+ This should do something like ``select(field).from_statement(stmt)`` to
+ reconfigure ORM information to the textual SQL statement.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def _process_row(row: Row) -> T:
+ """Process a SELECT-ed row into the end value."""
+ raise NotImplementedError
+
+ def __repr__(self) -> str:
+ counter = "item" if (length := len(self)) == 1 else "items"
+ return f"LazySelectSequence([{length} {counter}])"
+
+ def __str__(self) -> str:
+ counter = "item" if (length := len(self)) == 1 else "items"
+ return f"LazySelectSequence([{length} {counter}])"
+
+ def __getstate__(self) -> Any:
+ # We don't want to go to the trouble of serializing SQLAlchemy objects.
+ # Converting the statement into a SQL string is the best we can get.
+ # The literal_binds compile argument inlines all the values into the SQL
+ # string to simplify cross-process communication as much as possible.
+ # Theoretically we can do the same for count(), but I think it should be
+ # performant enough to calculate only that eagerly.
+ s1 = str(self._select_asc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True}))
+ s2 = str(self._select_desc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True}))
+ return (s1, s2, len(self))
+
+ def __setstate__(self, state: Any) -> None:
+ s1, s2, self._len = state
+ self._select_asc = self._rebuild_select(text(s1))
+ self._select_desc = self._rebuild_select(text(s2))
+ self._session = get_current_task_instance_session()
+
+ def __bool__(self) -> bool:
+ return check_query_exists(self._select_asc, session=self._session)
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, collections.abc.Sequence):
+ return NotImplemented
+ z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
+ return all(x == y for x, y in z)
+
+ def __reversed__(self) -> Iterator[T]:
+ return iter(self._process_row(r) for r in self._session.execute(self._select_desc))
+
+ def __iter__(self) -> Iterator[T]:
+ return iter(self._process_row(r) for r in self._session.execute(self._select_asc))
+
+ def __len__(self) -> int:
+ if self._len is None:
+ self._len = get_query_count(self._select_asc, session=self._session)
+ return self._len
+
+ @overload
+ def __getitem__(self, key: int) -> T: ...
+
+ @overload
+ def __getitem__(self, key: slice) -> Self: ...
+
+ def __getitem__(self, key: int | slice) -> T | Self:
+ if not isinstance(key, int):
+ raise ValueError("non-index access is not supported")
+ if key >= 0:
+ stmt = self._select_asc.offset(key)
+ else:
+ stmt = self._select_desc.offset(-1 - key)
+ if (row := self._session.execute(stmt.limit(1)).one_or_none()) is None:
+ raise IndexError(key)
+ return self._process_row(row)
diff --git a/airflow/utils/task_instance_session.py b/airflow/utils/task_instance_session.py
index 9d4dd95..bb9741b 100644
--- a/airflow/utils/task_instance_session.py
+++ b/airflow/utils/task_instance_session.py
@@ -22,7 +22,7 @@
import traceback
from typing import TYPE_CHECKING
-from airflow.utils.session import create_session
+from airflow import settings
if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -41,7 +41,7 @@
log.warning('File: "%s", %s , in %s', filename, line_number, name)
if line:
log.warning(" %s", line.strip())
- __current_task_instance_session = create_session()
+ __current_task_instance_session = settings.Session()
return __current_task_instance_session
diff --git a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst
index 739bc6f..dd6f42b 100644
--- a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst
+++ b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst
@@ -53,7 +53,7 @@
In the above example, ``values`` received by ``sum_it`` is an aggregation of all values returned by each mapped instance of ``add_one``. However, since it is impossible to know how many instances of ``add_one`` we will have in advance, ``values`` is not a normal list, but a "lazy sequence" that retrieves each individual value only when asked. Therefore, if you run ``print(values)`` directly, you would get something like this::
- LazyXComAccess(dag_id='simple_mapping', run_id='test_run', task_id='add_one')
+ LazySelectSequence([15 items])
You can use normal sequence syntax on this object (e.g. ``values[0]``), or iterate through it normally with a ``for`` loop. ``list(values)`` will give you a "real" ``list``, but since this would eagerly load values from *all* of the referenced upstream mapped tasks, you must be aware of the potential performance implications if the mapped number is large.
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 8afc1ab..11d833a 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -73,7 +73,7 @@
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.variable import Variable
-from airflow.models.xcom import LazyXComAccess, XCom
+from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
@@ -93,6 +93,7 @@
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.task_group import TaskGroup
+from airflow.utils.task_instance_session import set_current_task_instance_session
from airflow.utils.types import DagRunType
from airflow.utils.xcom import XCOM_RETURN_KEY
from tests.models import DEFAULT_DATE, TEST_DAGS_FOLDER
@@ -4355,20 +4356,22 @@
run: DagRun = dag_maker.create_dagrun()
run.get_task_instance("t", session=session).xcom_push("xxx", 123, session=session)
- query = session.query(XCom.value).filter_by(
- dag_id=run.dag_id,
- run_id=run.run_id,
- task_id="t",
- map_index=-1,
- key="xxx",
- )
-
- original = LazyXComAccess.build_from_xcom_query(query)
- processed = pickle.loads(pickle.dumps(original))
+ with set_current_task_instance_session(session=session):
+ original = LazyXComSelectSequence.from_select(
+ select(XCom.value).filter_by(
+ dag_id=run.dag_id,
+ run_id=run.run_id,
+ task_id="t",
+ map_index=-1,
+ key="xxx",
+ ),
+ order_by=(),
+ )
+ processed = pickle.loads(pickle.dumps(original))
# After the object went through pickling, the underlying ORM query should be
# replaced by one backed by a literal SQL string with all variables binded.
- sql_lines = [line.strip() for line in str(processed._query.statement.compile(None)).splitlines()]
+ sql_lines = [line.strip() for line in str(processed._select_asc.compile(None)).splitlines()]
assert sql_lines == _get_lazy_xcom_access_expected_sql_lines()
assert len(processed) == 1
@@ -4398,7 +4401,7 @@
# Simply pulling the joined XCom value should not deserialize.
joined = ti_2.xcom_pull("task_1", session=session)
- assert isinstance(joined, LazyXComAccess)
+ assert isinstance(joined, LazyXComSelectSequence)
assert mock_deserialize_value.call_count == 0
# Only when we go through the iterable does deserialization happen.