Unify lazy db sequence implementations (#39426)

diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py
index 6500d80..332a04e 100644
--- a/airflow/lineage/__init__.py
+++ b/airflow/lineage/__init__.py
@@ -130,10 +130,10 @@
             # Remove auto and task_ids
             self.inlets = [i for i in self.inlets if not isinstance(i, str)]
 
-            # We manually create a session here since xcom_pull returns a LazyXComAccess iterator.
-            # If we do not pass a session a new session will be created, however that session will not be
-            # properly closed and will remain open. After we are done iterating we can safely close this
-            # session.
+            # We manually create a session here since xcom_pull returns a
+            # LazySelectSequence proxy. If we do not pass a session, a new one
+            # will be created, but that session will not be properly closed.
+            # After we are done iterating, we can safely close this session.
             with create_session() as session:
                 _inlets = self.xcom_pull(
                     context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
diff --git a/airflow/models/base.py b/airflow/models/base.py
index 1cde7d7..e9f86f8 100644
--- a/airflow/models/base.py
+++ b/airflow/models/base.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from sqlalchemy import Column, Integer, MetaData, String, text
 from sqlalchemy.orm import registry
@@ -48,7 +48,10 @@
 mapper_registry = registry(metadata=metadata)
 _sentinel = object()
 
-Base: Any = mapper_registry.generate_base()
+if TYPE_CHECKING:
+    Base = Any
+else:
+    Base = mapper_registry.generate_base()
 
 ID_LEN = 250
 
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index ef2a41a..1a9d1e0 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -97,7 +97,7 @@
 from airflow.models.taskinstancekey import TaskInstanceKey
 from airflow.models.taskmap import TaskMap
 from airflow.models.taskreschedule import TaskReschedule
-from airflow.models.xcom import LazyXComAccess, XCom
+from airflow.models.xcom import LazyXComSelectSequence, XCom
 from airflow.plugins_manager import integrate_macros_plugins
 from airflow.sentry import Sentry
 from airflow.settings import task_instance_mutation_hook
@@ -3358,34 +3358,37 @@
                 return default
             if map_indexes is not None or first.map_index < 0:
                 return XCom.deserialize_value(first)
-            query = query.order_by(None).order_by(XCom.map_index.asc())
-            return LazyXComAccess.build_from_xcom_query(query)
+            return LazyXComSelectSequence.from_select(
+                query.with_entities(XCom.value).order_by(None).statement,
+                order_by=[XCom.map_index],
+                session=session,
+            )
 
         # At this point either task_ids or map_indexes is explicitly multi-value.
         # Order return values to match task_ids and map_indexes ordering.
-        query = query.order_by(None)
+        ordering = []
         if task_ids is None or isinstance(task_ids, str):
-            query = query.order_by(XCom.task_id)
+            ordering.append(XCom.task_id)
+        elif task_id_whens := {tid: i for i, tid in enumerate(task_ids)}:
+            ordering.append(case(task_id_whens, value=XCom.task_id))
         else:
-            task_id_whens = {tid: i for i, tid in enumerate(task_ids)}
-            if task_id_whens:
-                query = query.order_by(case(task_id_whens, value=XCom.task_id))
-            else:
-                query = query.order_by(XCom.task_id)
+            ordering.append(XCom.task_id)
         if map_indexes is None or isinstance(map_indexes, int):
-            query = query.order_by(XCom.map_index)
+            ordering.append(XCom.map_index)
         elif isinstance(map_indexes, range):
             order = XCom.map_index
             if map_indexes.step < 0:
                 order = order.desc()
-            query = query.order_by(order)
+            ordering.append(order)
+        elif map_index_whens := {map_index: i for i, map_index in enumerate(map_indexes)}:
+            ordering.append(case(map_index_whens, value=XCom.map_index))
         else:
-            map_index_whens = {map_index: i for i, map_index in enumerate(map_indexes)}
-            if map_index_whens:
-                query = query.order_by(case(map_index_whens, value=XCom.map_index))
-            else:
-                query = query.order_by(XCom.map_index)
-        return LazyXComAccess.build_from_xcom_query(query)
+            ordering.append(XCom.map_index)
+        return LazyXComSelectSequence.from_select(
+            query.with_entities(XCom.value).order_by(None).statement,
+            order_by=ordering,
+            session=session,
+        )
 
     @provide_session
     def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 7a6695c..fe1ebad 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -17,18 +17,14 @@
 # under the License.
 from __future__ import annotations
 
-import collections.abc
-import contextlib
 import inspect
-import itertools
 import json
 import logging
 import pickle
 import warnings
-from functools import cached_property, wraps
-from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Iterable, cast, overload
 
-import attr
 from sqlalchemy import (
     Column,
     ForeignKeyConstraint,
@@ -38,6 +34,7 @@
     PrimaryKeyConstraint,
     String,
     delete,
+    select,
     text,
 )
 from sqlalchemy.dialects.mysql import LONGBLOB
@@ -45,12 +42,12 @@
 from sqlalchemy.orm import Query, reconstructor, relationship
 from sqlalchemy.orm.exc import NoResultFound
 
-from airflow import settings
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf
 from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
 from airflow.utils import timezone
+from airflow.utils.db import LazySelectSequence
 from airflow.utils.helpers import exactly_one, is_container
 from airflow.utils.json import XComDecoder, XComEncoder
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -70,7 +67,9 @@
     import datetime
 
     import pendulum
+    from sqlalchemy.engine import Row
     from sqlalchemy.orm import Session
+    from sqlalchemy.sql.expression import Select, TextClause
 
     from airflow.models.taskinstancekey import TaskInstanceKey
 
@@ -222,11 +221,11 @@
             if dag_run_id is None:
                 raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")
 
-        # Seamlessly resolve LazyXComAccess to a list. This is intended to work
+        # Seamlessly resolve LazySelectSequence to a list. This intends to work
         # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
         # it's pushed into XCom, the user should be aware of the performance
         # implications, and this avoids leaking the implementation detail.
-        if isinstance(value, LazyXComAccess):
+        if isinstance(value, LazySelectSequence):
             warning_message = (
                 "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
                 "to list, which may degrade performance. Review resource "
@@ -716,111 +715,19 @@
         return BaseXCom._deserialize_value(self, True)
 
 
-class _LazyXComAccessIterator(collections.abc.Iterator):
-    def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None:
-        self._cm = cm
-        self._entered = False
-
-    def __del__(self) -> None:
-        if self._entered:
-            self._cm.__exit__(None, None, None)
-
-    def __iter__(self) -> collections.abc.Iterator:
-        return self
-
-    def __next__(self) -> Any:
-        return XCom.deserialize_value(next(self._it))
-
-    @cached_property
-    def _it(self) -> collections.abc.Iterator:
-        self._entered = True
-        return iter(self._cm.__enter__())
-
-
-@attr.define(slots=True)
-class LazyXComAccess(collections.abc.Sequence):
-    """Wrapper to lazily pull XCom with a sequence-like interface.
-
-    Note that since the session bound to the parent query may have died when we
-    actually access the sequence's content, we must create a new session
-    for every function call with ``with_session()``.
+class LazyXComSelectSequence(LazySelectSequence[Any]):
+    """List-like interface to lazily access XCom values.
 
     :meta private:
     """
 
-    _query: Query
-    _len: int | None = attr.ib(init=False, default=None)
+    @staticmethod
+    def _rebuild_select(stmt: TextClause) -> Select:
+        return select(XCom.value).from_statement(stmt)
 
-    @classmethod
-    def build_from_xcom_query(cls, query: Query) -> LazyXComAccess:
-        return cls(query=query.with_entities(XCom.value))
-
-    def __repr__(self) -> str:
-        return f"LazyXComAccess([{len(self)} items])"
-
-    def __str__(self) -> str:
-        return str(list(self))
-
-    def __eq__(self, other: Any) -> bool:
-        if isinstance(other, (list, LazyXComAccess)):
-            z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
-            return all(x == y for x, y in z)
-        return NotImplemented
-
-    def __getstate__(self) -> Any:
-        # We don't want to go to the trouble of serializing the entire Query
-        # object, including its filters, hints, etc. (plus SQLAlchemy does not
-        # provide a public API to inspect a query's contents). Converting the
-        # query into a SQL string is the best we can get. Theoratically we can
-        # do the same for count(), but I think it should be performant enough to
-        # calculate only that eagerly.
-        with self._get_bound_query() as query:
-            statement = query.statement.compile(
-                query.session.get_bind(),
-                # This inlines all the values into the SQL string to simplify
-                # cross-process commuinication as much as possible.
-                compile_kwargs={"literal_binds": True},
-            )
-            return (str(statement), query.count())
-
-    def __setstate__(self, state: Any) -> None:
-        statement, self._len = state
-        self._query = Query(XCom.value).from_statement(text(statement))
-
-    def __len__(self):
-        if self._len is None:
-            with self._get_bound_query() as query:
-                self._len = query.count()
-        return self._len
-
-    def __iter__(self):
-        return _LazyXComAccessIterator(self._get_bound_query())
-
-    def __getitem__(self, key):
-        if not isinstance(key, int):
-            raise ValueError("only support index access for now")
-        try:
-            with self._get_bound_query() as query:
-                r = query.offset(key).limit(1).one()
-        except NoResultFound:
-            raise IndexError(key) from None
-        return XCom.deserialize_value(r)
-
-    @contextlib.contextmanager
-    def _get_bound_query(self) -> Generator[Query, None, None]:
-        # Do we have a valid session already?
-        if self._query.session and self._query.session.is_active:
-            yield self._query
-            return
-
-        Session = getattr(settings, "Session", None)
-        if Session is None:
-            raise RuntimeError("Session must be set before!")
-        session = Session()
-        try:
-            yield self._query.with_session(session)
-        finally:
-            session.close()
+    @staticmethod
+    def _process_row(row: Row) -> Any:
+        return XCom.deserialize_value(row)
 
 
 def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
diff --git a/airflow/typing_compat.py b/airflow/typing_compat.py
index 5ae2d23..ba96c92 100644
--- a/airflow/typing_compat.py
+++ b/airflow/typing_compat.py
@@ -23,6 +23,7 @@
     "Literal",
     "ParamSpec",
     "Protocol",
+    "Self",
     "TypedDict",
     "TypeGuard",
     "runtime_checkable",
@@ -45,3 +46,8 @@
     from typing import ParamSpec, TypeGuard
 else:
     from typing_extensions import ParamSpec, TypeGuard
+
+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    from typing_extensions import Self
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 2b73020..58e688f 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -32,25 +32,26 @@
     KeysView,
     Mapping,
     MutableMapping,
-    Sequence,
     SupportsIndex,
     ValuesView,
-    overload,
 )
 
 import attrs
 import lazy_object_proxy
+from sqlalchemy import select
 
 from airflow.datasets import Dataset, coerce_to_uri
 from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.models.dataset import DatasetEvent, DatasetModel
+from airflow.utils.db import LazySelectSequence
 from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
+    from sqlalchemy.engine import Row
     from sqlalchemy.orm import Session
-    from sqlalchemy.sql.expression import Select
+    from sqlalchemy.sql.expression import Select, TextClause
 
     from airflow.models.baseoperator import BaseOperator
-    from airflow.models.dataset import DatasetEvent
 
 # NOTE: Please keep this in sync with the following:
 # * Context in airflow/utils/context.pyi.
@@ -187,57 +188,23 @@
         return self._dict[uri]
 
 
-@attrs.define()
-class InletEventsAccessor(Sequence["DatasetEvent"]):
-    """Lazy sequence to access inlet dataset events.
+class LazyDatasetEventSelectSequence(LazySelectSequence[DatasetEvent]):
+    """List-like interface to lazily access DatasetEvent rows.
 
     :meta private:
     """
 
-    _uri: str
-    _session: Session
+    @staticmethod
+    def _rebuild_select(stmt: TextClause) -> Select:
+        return select(DatasetEvent).from_statement(stmt)
 
-    def _get_select_stmt(self, *, reverse: bool = False) -> Select:
-        from sqlalchemy import select
-
-        from airflow.models.dataset import DatasetEvent, DatasetModel
-
-        stmt = select(DatasetEvent).join(DatasetEvent.dataset).where(DatasetModel.uri == self._uri)
-        if reverse:
-            return stmt.order_by(DatasetEvent.timestamp.desc())
-        return stmt.order_by(DatasetEvent.timestamp.asc())
-
-    def __reversed__(self) -> Iterator[DatasetEvent]:
-        return iter(self._session.scalar(self._get_select_stmt(reverse=True)))
-
-    def __iter__(self) -> Iterator[DatasetEvent]:
-        return iter(self._session.scalar(self._get_select_stmt()))
-
-    @overload
-    def __getitem__(self, key: int) -> DatasetEvent: ...
-
-    @overload
-    def __getitem__(self, key: slice) -> Sequence[DatasetEvent]: ...
-
-    def __getitem__(self, key: int | slice) -> DatasetEvent | Sequence[DatasetEvent]:
-        if not isinstance(key, int):
-            raise ValueError("non-index access is not supported")
-        if key >= 0:
-            stmt = self._get_select_stmt().offset(key)
-        else:
-            stmt = self._get_select_stmt(reverse=True).offset(-1 - key)
-        if (event := self._session.scalar(stmt.limit(1))) is not None:
-            return event
-        raise IndexError(key)
-
-    def __len__(self) -> int:
-        from sqlalchemy import func, select
-
-        return self._session.scalar(select(func.count()).select_from(self._get_select_stmt()))
+    @staticmethod
+    def _process_row(row: Row) -> DatasetEvent:
+        return row[0]
 
 
 @attrs.define(init=False)
-class InletEventsAccessors(Mapping[str, InletEventsAccessor]):
+class InletEventsAccessors(Mapping[str, LazyDatasetEventSelectSequence]):
     """Lazy mapping for inlet dataset events accessors.
 
     :meta private:
@@ -258,14 +225,18 @@
     def __len__(self) -> int:
         return len(self._inlets)
 
-    def __getitem__(self, key: int | str | Dataset) -> InletEventsAccessor:
+    def __getitem__(self, key: int | str | Dataset) -> LazyDatasetEventSelectSequence:
         if isinstance(key, int):  # Support index access; it's easier for trivial cases.
             dataset = self._inlets[key]
             if not isinstance(dataset, Dataset):
                 raise IndexError(key)
         else:
             dataset = self._datasets[coerce_to_uri(key)]
-        return InletEventsAccessor(dataset.uri, session=self._session)
+        return LazyDatasetEventSelectSequence.from_select(
+            select(DatasetEvent).join(DatasetEvent.dataset).where(DatasetModel.uri == dataset.uri),
+            order_by=[DatasetEvent.timestamp],
+            session=self._session,
+        )
 
 
 class AirflowContextDeprecationWarning(RemovedInAirflow3Warning):
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index e7afb04..fb64169 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -17,8 +17,10 @@
 # under the License.
 from __future__ import annotations
 
+import collections.abc
 import contextlib
 import enum
+import itertools
 import json
 import logging
 import os
@@ -27,8 +29,20 @@
 import warnings
 from dataclasses import dataclass
 from tempfile import gettempdir
-from typing import TYPE_CHECKING, Callable, Generator, Iterable
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Generator,
+    Iterable,
+    Iterator,
+    Protocol,
+    Sequence,
+    TypeVar,
+    overload,
+)
 
+import attrs
 from sqlalchemy import (
     Table,
     and_,
@@ -54,16 +68,28 @@
 
 # TODO: remove create_session once we decide to break backward compatibility
 from airflow.utils.session import NEW_SESSION, create_session, provide_session  # noqa: F401
+from airflow.utils.task_instance_session import get_current_task_instance_session
 
 if TYPE_CHECKING:
     from alembic.runtime.environment import EnvironmentContext
     from alembic.script import ScriptDirectory
+    from sqlalchemy.engine import Row
     from sqlalchemy.orm import Query, Session
-    from sqlalchemy.sql.elements import ClauseElement
+    from sqlalchemy.sql.elements import ClauseElement, TextClause
     from sqlalchemy.sql.selectable import Select
 
-    from airflow.models.base import Base
     from airflow.models.connection import Connection
+    from airflow.typing_compat import Self
+
+    # TODO: Import this from sqlalchemy.orm instead when switching to SQLA 2.
+    # https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.MappedClassProtocol
+    class MappedClassProtocol(Protocol):
+        """Protocol for SQLALchemy model base."""
+
+        __tablename__: str
+
+
+T = TypeVar("T")
 
 log = logging.getLogger(__name__)
 
@@ -1028,7 +1054,7 @@
             )
 
 
-def reflect_tables(tables: list[Base | str] | None, session):
+def reflect_tables(tables: list[MappedClassProtocol | str] | None, session):
     """
     When running checks prior to upgrades, we use reflection to determine current state of the database.
 
@@ -1416,7 +1442,7 @@
         ref_table="task_instance",
     )
 
-    models_list: list[tuple[Base, str, BadReferenceConfig]] = [
+    models_list: list[tuple[MappedClassProtocol, str, BadReferenceConfig]] = [
         (TaskInstance, "2.2", missing_dag_run_config),
         (TaskReschedule, "2.2", missing_ti_config),
         (RenderedTaskInstanceFields, "2.3", missing_ti_config),
@@ -1875,7 +1901,7 @@
 
 
 def get_query_count(query_stmt: Select, *, session: Session) -> int:
-    """Get count of query.
+    """Get count of a query.
 
     A SELECT COUNT() FROM is issued against the subquery built from the
     given statement. The ORDER BY clause is stripped from the statement
@@ -1888,8 +1914,21 @@
     return session.scalar(count_stmt)
 
 
+def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
+    """Check whether there is at least one row matching a query.
+
+    A SELECT 1 FROM is issued against the subquery built from the given
+    statement. The ORDER BY clause is stripped from the statement since it's
+    unnecessary, and can impact query planning and degrade performance.
+
+    :meta private:
+    """
+    count_stmt = select(literal(True)).select_from(query_stmt.order_by(None).subquery())
+    return session.scalar(count_stmt)
+
+
 def exists_query(*where: ClauseElement, session: Session) -> bool:
-    """Check whether there is at least one row matching given clause.
+    """Check whether there is at least one row matching given clauses.
 
     This does a SELECT 1 WHERE ... LIMIT 1 and check the result.
 
@@ -1897,3 +1936,122 @@
     """
     stmt = select(literal(True)).where(*where).limit(1)
     return session.scalar(stmt) is not None
+
+
+@attrs.define(slots=True)
+class LazySelectSequence(Sequence[T]):
+    """List-like interface to lazily access a database model query.
+
+    The intended use case is inside a task execution context, where we manage an
+    active SQLAlchemy session in the background.
+
+    This is an abstract base class. Each use case should subclass, and implement
+    the following static methods:
+
+    * ``_rebuild_select`` is called when a lazy sequence is unpickled. Since it
+      is not easy to pickle SQLAlchemy constructs, this class serializes the
+      SELECT statements into plain text to storage. This method is called on
+      deserialization to convert the textual clause back into an ORM SELECT.
+    * ``_process_row`` is called when an item is accessed. The lazy sequence
+      uses ``session.execute()`` to fetch rows from the database, and this
+      method should know how to process each row into a value.
+
+    :meta private:
+    """
+
+    _select_asc: ClauseElement
+    _select_desc: ClauseElement
+    _session: Session = attrs.field(kw_only=True, factory=get_current_task_instance_session)
+    _len: int | None = attrs.field(init=False, default=None)
+
+    @classmethod
+    def from_select(
+        cls,
+        select: Select,
+        *,
+        order_by: Sequence[ClauseElement],
+        session: Session | None = None,
+    ) -> Self:
+        s1 = select
+        for col in order_by:
+            s1 = s1.order_by(col.asc())
+        s2 = select
+        for col in order_by:
+            s2 = s2.order_by(col.desc())
+        return cls(s1, s2, session=session or get_current_task_instance_session())
+
+    @staticmethod
+    def _rebuild_select(stmt: TextClause) -> Select:
+        """Rebuild a textual statement into an ORM-configured SELECT statement.
+
+        This should do something like ``select(field).from_statement(stmt)`` to
+        reconfigure ORM information to the textual SQL statement.
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def _process_row(row: Row) -> T:
+        """Process a SELECT-ed row into the end value."""
+        raise NotImplementedError
+
+    def __repr__(self) -> str:
+        counter = "item" if (length := len(self)) == 1 else "items"
+        return f"LazySelectSequence([{length} {counter}])"
+
+    def __str__(self) -> str:
+        counter = "item" if (length := len(self)) == 1 else "items"
+        return f"LazySelectSequence([{length} {counter}])"
+
+    def __getstate__(self) -> Any:
+        # We don't want to go to the trouble of serializing SQLAlchemy objects.
+        # Converting the statement into a SQL string is the best we can get.
+        # The literal_binds compile argument inlines all the values into the SQL
+        # string to simplify cross-process commuinication as much as possible.
+        # Theoratically we can do the same for count(), but I think it should be
+        # performant enough to calculate only that eagerly.
+        s1 = str(self._select_asc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True}))
+        s2 = str(self._select_desc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True}))
+        return (s1, s2, len(self))
+
+    def __setstate__(self, state: Any) -> None:
+        s1, s2, self._len = state
+        self._select_asc = self._rebuild_select(text(s1))
+        self._select_desc = self._rebuild_select(text(s2))
+        self._session = get_current_task_instance_session()
+
+    def __bool__(self) -> bool:
+        return check_query_exists(self._select_asc, session=self._session)
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, collections.abc.Sequence):
+            return NotImplemented
+        z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
+        return all(x == y for x, y in z)
+
+    def __reversed__(self) -> Iterator[T]:
+        return iter(self._process_row(r) for r in self._session.execute(self._select_desc))
+
+    def __iter__(self) -> Iterator[T]:
+        return iter(self._process_row(r) for r in self._session.execute(self._select_asc))
+
+    def __len__(self) -> int:
+        if self._len is None:
+            self._len = get_query_count(self._select_asc, session=self._session)
+        return self._len
+
+    @overload
+    def __getitem__(self, key: int) -> T: ...
+
+    @overload
+    def __getitem__(self, key: slice) -> Self: ...
+
+    def __getitem__(self, key: int | slice) -> T | Self:
+        if not isinstance(key, int):
+            raise ValueError("non-index access is not supported")
+        if key >= 0:
+            stmt = self._select_asc.offset(key)
+        else:
+            stmt = self._select_desc.offset(-1 - key)
+        if (row := self._session.execute(stmt.limit(1)).one_or_none()) is None:
+            raise IndexError(key)
+        return self._process_row(row)
diff --git a/airflow/utils/task_instance_session.py b/airflow/utils/task_instance_session.py
index 9d4dd95..bb9741b 100644
--- a/airflow/utils/task_instance_session.py
+++ b/airflow/utils/task_instance_session.py
@@ -22,7 +22,7 @@
 import traceback
 from typing import TYPE_CHECKING
 
-from airflow.utils.session import create_session
+from airflow import settings
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -41,7 +41,7 @@
             log.warning('File: "%s", %s , in %s', filename, line_number, name)
             if line:
                 log.warning("  %s", line.strip())
-        __current_task_instance_session = create_session()
+        __current_task_instance_session = settings.Session()
     return __current_task_instance_session
 
 
diff --git a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst
index 739bc6f..dd6f42b 100644
--- a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst
+++ b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst
@@ -53,7 +53,7 @@
 
     In the above example, ``values`` received by ``sum_it`` is an aggregation of all values returned by each mapped instance of ``add_one``. However, since it is impossible to know how many instances of ``add_one`` we will have in advance, ``values`` is not a normal list, but a "lazy sequence" that retrieves each individual value only when asked. Therefore, if you run ``print(values)`` directly, you would get something like this::
 
-        LazyXComAccess(dag_id='simple_mapping', run_id='test_run', task_id='add_one')
+        LazySelectSequence([15 items])
 
     You can use normal sequence syntax on this object (e.g. ``values[0]``), or iterate through it normally with a ``for`` loop. ``list(values)`` will give you a "real" ``list``, but since this would eagerly load values from *all* of the referenced upstream mapped tasks, you must be aware of the potential performance implications if the mapped number is large.
 
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 8afc1ab..11d833a 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -73,7 +73,7 @@
 from airflow.models.taskmap import TaskMap
 from airflow.models.taskreschedule import TaskReschedule
 from airflow.models.variable import Variable
-from airflow.models.xcom import LazyXComAccess, XCom
+from airflow.models.xcom import LazyXComSelectSequence, XCom
 from airflow.operators.bash import BashOperator
 from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import PythonOperator
@@ -93,6 +93,7 @@
 from airflow.utils.session import create_session, provide_session
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.task_group import TaskGroup
+from airflow.utils.task_instance_session import set_current_task_instance_session
 from airflow.utils.types import DagRunType
 from airflow.utils.xcom import XCOM_RETURN_KEY
 from tests.models import DEFAULT_DATE, TEST_DAGS_FOLDER
@@ -4355,20 +4356,22 @@
     run: DagRun = dag_maker.create_dagrun()
     run.get_task_instance("t", session=session).xcom_push("xxx", 123, session=session)
 
-    query = session.query(XCom.value).filter_by(
-        dag_id=run.dag_id,
-        run_id=run.run_id,
-        task_id="t",
-        map_index=-1,
-        key="xxx",
-    )
-
-    original = LazyXComAccess.build_from_xcom_query(query)
-    processed = pickle.loads(pickle.dumps(original))
+    with set_current_task_instance_session(session=session):
+        original = LazyXComSelectSequence.from_select(
+            select(XCom.value).filter_by(
+                dag_id=run.dag_id,
+                run_id=run.run_id,
+                task_id="t",
+                map_index=-1,
+                key="xxx",
+            ),
+            order_by=(),
+        )
+        processed = pickle.loads(pickle.dumps(original))
 
     # After the object went through pickling, the underlying ORM query should be
     # replaced by one backed by a literal SQL string with all variables binded.
-    sql_lines = [line.strip() for line in str(processed._query.statement.compile(None)).splitlines()]
+    sql_lines = [line.strip() for line in str(processed._select_asc.compile(None)).splitlines()]
     assert sql_lines == _get_lazy_xcom_access_expected_sql_lines()
 
     assert len(processed) == 1
@@ -4398,7 +4401,7 @@
 
     # Simply pulling the joined XCom value should not deserialize.
     joined = ti_2.xcom_pull("task_1", session=session)
-    assert isinstance(joined, LazyXComAccess)
+    assert isinstance(joined, LazyXComSelectSequence)
     assert mock_deserialize_value.call_count == 0
 
     # Only when we go through the iterable does deserialization happen.