blob: 2ec9ce2e09fb8bce4b78e71b520526fd57088353 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Cross-language enum types: named, frozen, ordinal-indexed singletons.
An ``Enum`` subclass has one of two usage modes, distinguished by whether
its ``type_key`` is already registered in the FFI type system:
* **Closed Python enum** — fresh ``type_key``, variants declared once in
the class body. Behavior matches ``enum.Enum``.
* **Cross-language registry** — ``type_key`` also registered in C++ (or
another Python module). Python and C++ both contribute variants to
the same per-class registry, and consumers attach *extensible
attributes* to variants from any module at any time.
See :class:`Enum` for declaration forms and :meth:`Enum.def_attr` for
extensible attributes.
Storage layout (mirrors ``include/tvm/ffi/enum.h``):
* ``__ffi_enum_entries__`` — ``Dict[str, Enum]``, name → variant.
* ``__ffi_enum_attrs__`` — ``Dict[str, List[Any]]``, extensible-attr
name → column indexed by each variant's ordinal.
"""
from __future__ import annotations
import sys
import typing
from collections.abc import Callable, Iterator
from typing import Any, ClassVar
from typing_extensions import dataclass_transform
from .. import core
from ..container import Dict, List
from ..core import Object
from .c_class import c_class
from .field import Field, field
from .py_class import py_class
__all__ = [
"ENUM_ATTRS_ATTR",
"ENUM_ENTRIES_ATTR",
"ENUM_VALUE_ENTRIES_ATTR",
"Enum",
"EnumAttrMap",
"IntEnum",
"StrEnum",
"auto",
"entry",
]
#: TypeAttr column storing ``Dict[str, Enum]`` (instance name → singleton).
ENUM_ENTRIES_ATTR = "__ffi_enum_entries__"
#: TypeAttr column storing ``Dict[str, List[Any]]`` of per-variant attrs.
ENUM_ATTRS_ATTR = "__ffi_enum_attrs__"
#: TypeAttr column storing ``Dict[Any, Enum]`` (payload → singleton) —
#: parallel to :data:`ENUM_ENTRIES_ATTR` but keyed by the user-visible
#: payload (``IntEnum.value`` / ``StrEnum.value``). Populated by every
#: creator of a payload-carrying variant; consumed by FFI converters for
#: O(1) payload → variant resolution.
ENUM_VALUE_ENTRIES_ATTR = "__ffi_enum_value_entries__"
# ---------------------------------------------------------------------------
# entry() sentinel
# ---------------------------------------------------------------------------
class _EnumEntry:
"""Sentinel produced by :func:`entry`; consumed by ``Enum.__init_subclass__``.
Holds the positional and keyword arguments forwarded to the subclass's
``__init__`` when the variant is materialized. ``_value`` and ``_name``
are auto-assigned (dense ordinal and class-body name) and must not
appear in the captured arguments.
"""
__slots__ = ("args", "kwargs")
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.args: tuple[Any, ...] = args
self.kwargs: dict[str, Any] = kwargs
def __repr__(self) -> str:
parts = [repr(a) for a in self.args]
parts.extend(f"{k}={v!r}" for k, v in self.kwargs.items())
return f"entry({', '.join(parts)})"
def entry(*args: Any, **kwargs: Any) -> Any:
"""Declare a new enum variant with values for its declared fields.
``entry(...)`` is a class-body sentinel; it never produces a real
instance. At class creation, :meth:`Enum.__init_subclass__` scans
for these sentinels and, for each one, constructs a singleton variant
by forwarding the captured positional and keyword arguments to the
subclass's ``__init__``, together with an auto-assigned
:attr:`~Enum._value` (dense ordinal) and :attr:`~Enum._name`
(class-body name).
Prefer :func:`auto` when a variant has no declared fields beyond the
auto-assigned ordinal and name — it expresses intent without the
empty-arg-list noise.
When the enum's ``type_key`` is C++-backed (registered via
``refl::ObjectDef``), only keyword arguments are supported — field
values are assigned via reflected setters keyed by name. Passing
positional arguments in that case raises :class:`TypeError`.
``entry(_value=...)`` and ``entry(_name=...)`` always raise
:class:`TypeError` because those fields are auto-assigned.
Examples
--------
Variant with declared fields:
.. code-block:: python
from typing import ClassVar
class Activation(Enum, type_key="my.Activation"):
output_zero: bool
is_monotonic: bool
relu: ClassVar[Activation] = entry(output_zero=True, is_monotonic=True)
gelu: ClassVar[Activation] = entry(output_zero=False, is_monotonic=False)
Returns
-------
object
An opaque sentinel. The declared return type is ``Any`` so that
``ClassVar[Cls] = entry(...)`` type-checks even though the sentinel
is not a real ``Cls``.
"""
return _EnumEntry(*args, **kwargs)
def auto() -> Any:
"""Declare a new enum variant with no declared fields.
Semantically equivalent to :func:`entry` called with no arguments but
reads more clearly for the common case where a variant differs from
its siblings only by name and ordinal. The resulting singleton has
only the auto-assigned :attr:`~Enum._value` and :attr:`~Enum._name`.
``auto()`` registers a *new* Python-side variant; it is not the right
tool for binding to a pre-existing C++-registered entry (use a bare
``ClassVar[Cls]`` annotation for that — see :class:`Enum`).
Examples
--------
.. code-block:: python
class Status(Enum, type_key="my.Status"):
ok = auto()
err = auto()
retry = auto()
assert Status.ok._value == 0
assert Status.err._name == "err"
Returns
-------
object
An opaque sentinel, the same kind returned by :func:`entry`. The
declared return type is ``Any`` so that both ``name = auto()`` and
``name: ClassVar[Cls] = auto()`` type-check.
"""
return _EnumEntry()
# ---------------------------------------------------------------------------
# Class-level helpers
# ---------------------------------------------------------------------------
class _ClassProperty:
"""Read-only descriptor whose getter receives the owning class.
Used for ``attr_dict`` so it works as a class-level attribute access
(e.g., ``Op.attr_dict["has_side_effects"]``) without needing a metaclass.
"""
__slots__ = ("_fget",)
def __init__(self, fget: Callable[[type], Any]) -> None:
self._fget = fget
def __get__(self, instance: Any, owner: type | None = None) -> Any:
cls = owner if owner is not None else type(instance)
return self._fget(cls)
def _install_payload_enum_behaviors(cls: type, *, user_defined_repr: bool) -> None:
"""Install stdlib-like instance behavior for ``IntEnum`` / ``StrEnum``.
Payload enums have two identities that plain ``Enum`` does not: the FFI
singleton identity (``Priority.low.same_as(...)``) and a user-visible raw
payload (``Priority.low.value == 10``). These methods make the Python
surface follow the payload where users expect stdlib-like behavior, while
keeping the object identity available through ``same_as``.
This runs after ``py_class`` / ``c_class`` because those decorators may
install generated object dunders. Installing here lets payload enums
override generated repr/equality/hash defaults, but still preserves methods
explicitly written in the enum subclass body.
"""
def _payload_enum_eq(self: Enum, other: object) -> Any:
"""Compare payload enum variants by payload and accept raw payloads.
Two variants of the same enum class compare like their ``value`` fields,
and a variant can compare directly with its raw payload type
(``Priority.low == 10``). Unrelated objects return ``NotImplemented``
so Python can try reflected comparison or fall back normally.
"""
if isinstance(other, type(self)):
return self.value == other.value # type: ignore[attr-defined]
payload_type = getattr(type(self), "__ffi_enum_payload_value_type__", None)
if payload_type is not None and isinstance(other, payload_type):
return self.value == other # type: ignore[attr-defined]
return NotImplemented
def _payload_enum_ne(self: Enum, other: object) -> Any:
"""Negate payload equality while preserving ``NotImplemented`` semantics."""
if (eq := _payload_enum_eq(self, other)) is not NotImplemented:
return not eq
return NotImplemented
def _payload_enum_str(self: Enum) -> str:
"""Render as the raw payload, matching ``int`` / ``str`` enum expectations."""
return str(self.value) # type: ignore[attr-defined]
def _payload_enum_hash(self: Enum) -> int:
"""Hash like the raw payload so equality and hashing stay consistent."""
return hash(self.value) # type: ignore[attr-defined]
_PAYLOAD_ENUM_DUNDERS = {
"__eq__": _payload_enum_eq,
"__ne__": _payload_enum_ne,
"__str__": _payload_enum_str,
"__hash__": _payload_enum_hash,
}
def _payload_enum_name(self: Enum) -> str:
"""Expose the declaration name without making plain ``Enum`` grow ``name``."""
return str(self._name)
def _payload_enum_repr(self: Enum) -> str:
"""Render as ``Class.member``, matching Python payload enum conventions."""
return f"{type(self).__name__}.{self.name}" # ty: ignore[unresolved-attribute]
if "name" not in cls.__dict__:
cls.name = property(_payload_enum_name) # ty: ignore[unresolved-attribute]
if not user_defined_repr:
cls.__repr__ = _payload_enum_repr # type: ignore[attr-defined]
for name, method in _PAYLOAD_ENUM_DUNDERS.items():
if name not in cls.__dict__:
setattr(cls, name, method)
# ---------------------------------------------------------------------------
# Enum base + EnumAttrMap
# ---------------------------------------------------------------------------
class _EnumMeta(type(Object)):
"""Metaclass for :class:`Enum` that wires class-level iteration and length.
Kept as a dedicated subclass of ``type(Object)`` so ``iter(MyEnum)`` and
``len(MyEnum)`` only resolve on :class:`Enum` subclasses, rather than
mutating the shared ``Object`` metaclass globally.
"""
def __iter__(cls) -> Iterator[Any]:
return iter(_ordered_entries(cls)) # ty: ignore[unresolved-attribute]
def __len__(cls) -> int:
return len(_ordered_entries(cls)) # ty: ignore[unresolved-attribute]
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
"""Construct enum variants from existing variants, payloads, or names.
Normal object construction is still delegated to ``Object`` / dataclass
initialization. The single-argument path is intercepted to match
stdlib enum lookup behavior:
* ``Priority(Priority.low)`` returns the same variant.
* ``Priority(10)`` returns the member whose payload value is ``10``.
* ``Priority("low")`` returns the member declared with that name.
Payload lookup is only enabled for classes with
``__ffi_enum_payload_value_type__`` so plain enums do not accidentally
interpret arbitrary class-body values as public payloads.
"""
if not kwargs and len(args) == 1:
value = args[0]
if isinstance(value, cls):
return value
payload_value_type = getattr(cls, "__ffi_enum_payload_value_type__", None)
if payload_value_type is not None and isinstance(value, payload_value_type):
value_entries = _value_entries_dict(cls)
if value_entries is not None and value in value_entries:
return value_entries[value]
if isinstance(value, str):
entries = _entries_dict(cls)
if entries is not None and value in entries:
return entries[value]
return super().__call__(*args, **kwargs)
@dataclass_transform(
eq_default=False,
order_default=False,
frozen_default=True,
field_specifiers=(Field, field, entry, auto),
)
@c_class("ffi.Enum", init=False)
class Enum(Object, metaclass=_EnumMeta):
"""A named-singleton registry with cross-language identity.
Subclasses declare variants: frozen, named, ordinal-indexed
singletons — the familiar enum pattern. Unlike ``enum.Enum``, an
``Enum`` subclass bound to an FFI-registered ``type_key`` has an
**open variant set**: C++ translation units and other Python modules
binding the same ``type_key`` can contribute variants to the shared
registry. Per-variant metadata can also be attached post-hoc via
:meth:`def_attr` as an *extensible attribute*, outside the class
definition.
For **closed, Python-only enums**, use a fresh ``type_key`` with
:func:`auto` / :func:`entry` — behavior matches ``enum.Enum``.
Attributes
----------
_value : int
Dense ordinal assigned at registration (0-indexed per class).
_name : str
The variant's string name key (e.g., ``"Add"`` for ``Op.Add``).
Closed Python enum
------------------
Pick a fresh ``type_key`` and list variants with :func:`auto` or
:func:`entry`. The variant set is fixed at class-definition time.
.. code-block:: python
class Priority(Enum, type_key="my.Priority"):
low = auto()
medium = auto()
high = auto()
# Variants with declared fields — values supplied via entry(...).
class Activation(Enum, type_key="my.Activation"):
output_zero: bool
is_monotonic: bool
relu: ClassVar[Activation] = entry(output_zero=True, is_monotonic=True)
gelu: ClassVar[Activation] = entry(output_zero=False, is_monotonic=False)
Cross-language registry
-----------------------
When ``type_key`` is already registered (typically by C++), the
Python class *binds* to the existing type rather than creating a
new one. Bare ``ClassVar[Cls]`` annotations bind to variants
already registered on the C++ side; :func:`entry` / :func:`auto`
still register fresh Python variants whose ordinals extend past the
C++ ones. All variants — regardless of origin — land in the same
per-class registry and are visible to every binder of the same
``type_key``.
.. code-block:: python
# Registered in C++ via refl::EnumDef<VariantObj>("Alpha")... .
class Variant(Enum, type_key="testing.TestEnumVariant"):
Alpha: ClassVar[Variant] # binds to C++-registered "Alpha"
Beta: ClassVar[Variant] # binds to C++-registered "Beta"
Declaration forms
-----------------
Four shapes are supported in the class body:
1. ``name = auto()`` — new variant with no declared fields.
2. ``name: ClassVar[Cls] = entry(**kwargs)`` — new variant; ``kwargs``
populate declared fields.
3. ``name = entry(**kwargs)`` — same as (2), without the ``ClassVar``
annotation.
4. ``name: ClassVar[Cls]`` — in cross-language mode, binds to an
existing C++-registered variant (error if unknown); otherwise
registers a new Python variant with only the auto-assigned
:attr:`_value` and :attr:`_name` (equivalent to ``name = auto()``).
Integer literals (``ok = 0``) are rejected: :attr:`_value` is
auto-assigned, so a user-supplied ordinal would either silently
duplicate or conflict. ``entry(_value=...)`` and ``entry(_name=...)``
raise :class:`TypeError` at class-body time.
Differences from ``enum.Enum``
------------------------------
* **Same**: hidden identity fields :attr:`_name`, :attr:`_value`,
iteration, identity
comparison; closed-set behavior when ``type_key`` is fresh.
* **Extended**: ``entry(**kwargs)`` replaces the tuple-RHS idiom;
``dataclass_transform`` gives native type-checker support; open
registry when ``type_key`` is shared across languages.
* **Different**: :attr:`_value` is always the ordinal; :meth:`def_attr`
adds extensible attributes outside the class schema. Use
:class:`IntEnum` / :class:`StrEnum` when you need a user-visible
``value`` field.
* **Not provided**: ``Flag`` / ``IntFlag``, member aliasing,
``_missing_`` hook.
Subclasses inherit :meth:`get`, :meth:`all_entries`, :meth:`def_attr`,
and the ``attr_dict`` class-level view.
"""
__slots__ = ()
_value: int
_name: str
__ffi_enum_payload_value_type__: object = None
def __init_subclass__(
cls,
*,
type_key: str | None = None,
frozen: bool = True,
init: bool = True,
repr: bool = True,
**kwargs: Any,
) -> None:
super().__init_subclass__(**kwargs)
if type_key is None:
return
user_defined_repr = "__repr__" in cls.__dict__
binders, python_entries = _collect_entry_declarations(cls)
cxx_backed = core._object_type_key_to_index(type_key) is not None
if cxx_backed:
c_class(type_key, init=init, repr=repr)(cls)
else:
py_class(type_key, frozen=frozen)(cls)
if getattr(cls, "__ffi_enum_payload_value_type__", None) is not None:
_install_payload_enum_behaviors(cls, user_defined_repr=user_defined_repr)
_resolve_entries(cls, binders, python_entries, type_key=type_key, cxx_backed=cxx_backed)
@classmethod
def get(cls, name: str) -> Enum:
"""Return the variant named *name*, or raise :class:`KeyError`."""
entries = _entries_dict(cls)
if entries is not None and name in entries:
return entries[name]
raise KeyError(f"{cls.__name__} has no variant named {name!r}")
@classmethod
def all_entries(cls) -> Iterator[Enum]:
"""Iterate over all variants, in ordinal (``_value``) order."""
return iter(_ordered_entries(cls))
@_ClassProperty
def attr_dict(cls: type) -> Any:
"""Live ``Dict[str, List[Any]]`` backing every extensible attribute.
The outer dict is keyed by extensible-attribute name; each value
is a list indexed by variant ordinal. Prefer :meth:`def_attr`
for normal per-variant reads and writes; this property is for
bulk inspection and for reading values written by C++
``EnumDef::set_attr``.
"""
return _attrs_dict(cls) or Dict({})
@classmethod
def def_attr(
cls,
name: str,
*,
default: Any = core.MISSING,
) -> EnumAttrMap:
"""Declare an *extensible attribute* column on this enum.
Extensible attributes let any consumer associate per-variant data
outside the enum's class-body schema — a lowering function
attached to an operator by a code generator, a cost model
registered only on some targets, a documentation string added
after the fact. Writes are last-write-wins for the same
``(variant, name)`` pair and visible to every consumer that
calls :meth:`def_attr` with the same name on the same enum,
including C++ code writing via ``EnumDef::set_attr``.
Extensible attributes differ from **declared fields**:
==================== ======================== ==========================
Concept Lives on Added by
==================== ======================== ==========================
Declared field The variant object Enum author, in class body
Extensible attribute ``__ffi_enum_attrs__`` Any consumer, any time
==================== ======================== ==========================
Rule of thumb: if the data is part of *what a variant is*,
declare a field in the class body; if it's part of *what a
consumer wants to know*, attach it with :meth:`def_attr`.
Parameters
----------
name
The extensible-attribute name (e.g., ``"has_side_effects"``).
Writes go to ``attr_dict[name]`` as a list indexed by each
variant's ordinal.
default
Value returned by ``attr[variant]`` when nothing was
registered for that variant. Left as ``MISSING`` to raise
:class:`KeyError` on unset variants instead. The default
is a property of *this* :class:`EnumAttrMap` view, not of
the underlying column: calling :meth:`def_attr` twice with
the same ``name`` but different defaults creates two views
that share every explicit write but may disagree on unset
variants — e.g., ``Op.def_attr("cost", default=0)`` and
``Op.def_attr("cost", default=-1)`` return ``0`` and ``-1``
respectively for a variant that was never written to.
Returns
-------
EnumAttrMap
Mutable view over the column. Use ``variant in attr`` to
distinguish an explicit write from a default-hit. ``None``
is reserved as the "unset" sentinel (matching C++
``EnumDef::set_attr`` padding), so ``attr[variant] = None``
raises :class:`TypeError` — store a typed wrapper (e.g. a
``0``/``False`` flag) when you need a falsy-but-present
value.
Notes
-----
:meth:`def_attr` is not a way to add fields to the enum's
schema, subclass frozen variants, or bypass the frozen-instance
invariant via ``setattr`` — for that, declare a field in the
class body instead.
"""
return EnumAttrMap(cls, name, default=default)
class EnumAttrMap:
"""Mutable per-variant view over an extensible-attribute column.
Returned by :meth:`Enum.def_attr`. Writes go to a ``List[Any]``
column keyed by extensible-attribute name inside the per-class
``__ffi_enum_attrs__`` dict. The list is indexed by each variant's
ordinal (``variant._value``) and padded with ``None`` as new variants
are registered. The column is shared across every consumer —
including C++ code writing via ``EnumDef::set_attr`` — and the data
is not a field on the variant object. See :meth:`Enum.def_attr` for
full semantics.
``None`` is reserved as the column's "unset" sentinel (matching the
C++ ``Any(nullptr)`` padding used by ``EnumDef::set_attr``), so
:meth:`__setitem__` rejects ``None`` with :class:`TypeError` — an
explicit ``attr[variant] = None`` would otherwise be
indistinguishable from never-written and surprise ``variant in attr``
/ :meth:`__getitem__` readers. To "clear" a previously written
value, register a mutable container once and mutate it in place.
"""
__slots__ = ("_default", "_enum_cls", "_name")
def __init__(self, enum_cls: type, name: str, *, default: Any = core.MISSING) -> None:
self._enum_cls = enum_cls
self._name = name
self._default = default
def _ordinal_of(self, variant: object) -> int:
if not isinstance(variant, self._enum_cls):
raise TypeError(
f"{self._enum_cls.__name__}.def_attr({self._name!r}) expects a "
f"{self._enum_cls.__name__} variant, got {type(variant).__name__}"
)
return int(variant._value) # type: ignore[attr-defined]
def _column(self, *, create: bool) -> Any | None:
"""Return the ``List[Any]`` column for this attribute; create if missing.
Returns ``None`` when ``create`` is false and the column doesn't exist.
"""
attrs = _attrs_dict(self._enum_cls) if not create else _ensure_attrs_dict(self._enum_cls)
if attrs is None:
return None
if self._name in attrs:
return attrs[self._name]
if not create:
return None
col = List([])
attrs[self._name] = col
return col
def __setitem__(self, variant: object, value: Any) -> None:
if value is None:
raise TypeError(
f"{self._enum_cls.__name__}.def_attr({self._name!r}): "
f"None is reserved as the 'unset' sentinel for extensible "
f"attributes and cannot be written explicitly."
)
ordinal = self._ordinal_of(variant)
col = self._column(create=True)
assert col is not None # create=True always materialises the column.
while len(col) <= ordinal:
col.append(None)
col[ordinal] = value
def __getitem__(self, variant: object) -> Any:
ordinal = self._ordinal_of(variant)
col = self._column(create=False)
if col is not None and ordinal < len(col):
v = col[ordinal]
if v is not None:
return v
if self._default is core.MISSING:
raise KeyError(
f"{self._enum_cls.__name__}.{variant._name} has no " # type: ignore[attr-defined]
f"extensible attribute {self._name!r} set"
)
return self._default
def __contains__(self, variant: object) -> bool:
if not isinstance(variant, self._enum_cls):
return False
try:
ordinal = self._ordinal_of(variant)
except TypeError:
return False
col = self._column(create=False)
return col is not None and ordinal < len(col) and col[ordinal] is not None
def get(self, variant: object, default: Any = None) -> Any:
"""Return the value for *variant*, or *default* if unset or foreign."""
if not isinstance(variant, self._enum_cls):
return default
try:
return self[variant]
except KeyError:
return default
@property
def name(self) -> str:
"""The extensible-attribute name passed to :meth:`Enum.def_attr`."""
return self._name
class IntEnum(Enum):
"""Enum variant base with a user-visible ``value: int`` field."""
__slots__ = ()
value: int
def __init_subclass__(
cls,
*,
type_key: str | None = None,
frozen: bool = True,
init: bool = True,
repr: bool = True,
**kwargs: Any,
) -> None:
_prepare_payload_enum_subclass(cls, value_type=int, base_name="IntEnum")
super().__init_subclass__(
type_key=type_key,
frozen=frozen,
init=init,
repr=repr,
**kwargs,
)
class StrEnum(Enum):
"""Enum variant base with a user-visible ``value: str`` field."""
__slots__ = ()
value: str
def __init_subclass__(
cls,
*,
type_key: str | None = None,
frozen: bool = True,
init: bool = True,
repr: bool = True,
**kwargs: Any,
) -> None:
_prepare_payload_enum_subclass(cls, value_type=str, base_name="StrEnum")
super().__init_subclass__(
type_key=type_key,
frozen=frozen,
init=init,
repr=repr,
**kwargs,
)
# ---------------------------------------------------------------------------
# TypeAttr accessors
# ---------------------------------------------------------------------------
def _entries_dict(cls: type) -> Any:
type_info = getattr(cls, "__tvm_ffi_type_info__", None)
if type_info is None:
return None
return core._lookup_type_attr(type_info.type_index, ENUM_ENTRIES_ATTR)
def _attrs_dict(cls: type) -> Any:
type_info = getattr(cls, "__tvm_ffi_type_info__", None)
if type_info is None:
return None
return core._lookup_type_attr(type_info.type_index, ENUM_ATTRS_ATTR)
def _value_entries_dict(cls: type) -> Any:
"""Return the live ``Dict[Any, Enum]`` payload-to-variant index, or None."""
type_info = getattr(cls, "__tvm_ffi_type_info__", None)
if type_info is None:
return None
return core._lookup_type_attr(type_info.type_index, ENUM_VALUE_ENTRIES_ATTR)
def _ordered_entries(cls: type) -> list[Any]:
"""Return all variants ordered by ordinal."""
entries = _entries_dict(cls)
if entries is None:
return []
ordered: list[Any] = [None] * len(entries)
for inst in entries.values():
ordered[int(inst._value)] = inst
return ordered
def _ensure_entries_dict(cls: type) -> Any:
"""Return the live ``__ffi_enum_entries__`` dict, registering it if absent."""
type_info = cls.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
entries = core._lookup_type_attr(type_info.type_index, ENUM_ENTRIES_ATTR)
if entries is not None:
return entries
entries = Dict({})
core._register_type_attr(type_info.type_index, ENUM_ENTRIES_ATTR, entries)
# Re-read so mutations go through the ref owned by the registry.
return core._lookup_type_attr(type_info.type_index, ENUM_ENTRIES_ATTR)
def _ensure_attrs_dict(cls: type) -> Any:
"""Return the live ``__ffi_enum_attrs__`` dict, registering it if absent."""
type_info = cls.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
attrs = core._lookup_type_attr(type_info.type_index, ENUM_ATTRS_ATTR)
if attrs is not None:
return attrs
attrs = Dict({})
core._register_type_attr(type_info.type_index, ENUM_ATTRS_ATTR, attrs)
return core._lookup_type_attr(type_info.type_index, ENUM_ATTRS_ATTR)
def _ensure_value_entries_dict(cls: type) -> Any:
"""Return the live ``__ffi_enum_value_entries__`` dict, registering it if absent."""
type_info = cls.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
entries = core._lookup_type_attr(type_info.type_index, ENUM_VALUE_ENTRIES_ATTR)
if entries is not None:
return entries
entries = Dict({})
core._register_type_attr(type_info.type_index, ENUM_VALUE_ENTRIES_ATTR, entries)
return core._lookup_type_attr(type_info.type_index, ENUM_VALUE_ENTRIES_ATTR)
# ---------------------------------------------------------------------------
# Class-body scanning + entry materialisation
# ---------------------------------------------------------------------------
def _collect_entry_declarations(
cls: type,
) -> tuple[list[str], dict[str, _EnumEntry]]:
"""Scan *cls.__dict__* for variant declarations.
Returns ``(binders, python_entries)`` in declaration order:
* *binders* — names annotated as ``ClassVar[Cls]`` with no assigned value.
Each either binds to an existing C++-registered entry with the same
name or registers a new blank Python entry.
* *python_entries* — names assigned an ``entry(...)`` sentinel (with or
without a ``ClassVar`` annotation). Each registers a new Python entry
using the captured args/kwargs.
Matched assignments are removed from ``cls.__dict__`` so that
``@c_class`` / ``@py_class`` don't misinterpret them as field defaults or
class constants.
"""
annotations = _own_annotations(cls)
dict_keys = set(cls.__dict__.keys())
binders: list[str] = []
ordinary_fields: set[str] = set()
payload_value_type = getattr(cls, "__ffi_enum_payload_value_type__", None)
for name, ann in annotations.items():
if name.startswith("_"):
continue
if _is_class_var(ann):
if name not in dict_keys:
binders.append(name)
else:
ordinary_fields.add(name)
python_entries: dict[str, _EnumEntry] = {}
for name, value in list(cls.__dict__.items()):
if name.startswith("_"):
continue
if isinstance(value, _EnumEntry):
python_entries[name] = value
try:
delattr(cls, name)
except AttributeError:
pass
elif payload_value_type is not None and name not in ordinary_fields:
if isinstance(value, (staticmethod, classmethod, property)) or callable(value):
continue
python_entries[name] = _EnumEntry(value=value)
try:
delattr(cls, name)
except AttributeError:
pass
return binders, python_entries
def _resolve_entries(
cls: type,
binders: list[str],
python_entries: dict[str, _EnumEntry],
*,
type_key: str,
cxx_backed: bool,
) -> None:
"""Materialise *binders* and *python_entries* into class-attribute singletons.
Processing order matches declaration order: ``binders`` first (because
their annotations appear before any class-body assignments), then
``python_entries`` in their class-body order. Each newly registered
entry gets a dense ordinal equal to the current entries-dict size, so
ordinals stay compact and stable across registrations.
A cxx-backed enum (``type_key`` was already registered in the FFI type
system before this Python subclass was created) supports mixing C++ and
Python entries: bare ``ClassVar[Cls]`` binders must name an existing
C++-registered entry, but ``entry(...)``/``auto()`` sentinels may add
fresh Python-side entries whose ordinals extend past the C++ entries.
"""
entries = _ensure_entries_dict(cls)
payload_value_type = getattr(cls, "__ffi_enum_payload_value_type__", None)
for name in binders:
if name in entries:
# Already materialised — either C++-registered or previously bound.
instance = entries[name]
setattr(cls, name, instance)
_index_payload(cls, instance, payload_value_type)
continue
if cxx_backed:
raise _cxx_backed_unknown_binder_error(cls, name, type_key, entries)
ordinal = len(entries)
instance = _instantiate(cls, args=(), kwargs={}, ordinal=ordinal, name=name)
entries[name] = instance
setattr(cls, name, instance)
_index_payload(cls, instance, payload_value_type)
for name, e in python_entries.items():
if name in entries:
raise RuntimeError(
f"Duplicate enum entry {name!r} for {cls.__name__}: already "
f"registered as ordinal {int(entries[name]._value)}."
)
if "_value" in e.kwargs or "_name" in e.kwargs:
raise TypeError(
f"{cls.__name__}.{name}: `_value` and `_name` are auto-assigned "
f"and must not appear in entry(...) arguments."
)
ordinal = len(entries)
instance = _instantiate_entry(
cls, entry=e, ordinal=ordinal, name=name, cxx_backed=cxx_backed
)
entries[name] = instance
setattr(cls, name, instance)
_index_payload(cls, instance, payload_value_type)
def _index_payload(cls: type, instance: Any, payload_value_type: type | None) -> None:
"""Record ``(instance.value → instance)`` in the value-entries column.
No-op for non-payload enums. For payload enums, the first writer of a
given payload wins — matches the "first-match" semantics of the linear
scan that FFI converters fall back to when this column is incomplete.
"""
if payload_value_type is None:
return
payload = getattr(instance, "value", None)
if payload is None:
return
value_entries = _ensure_value_entries_dict(cls)
if payload not in value_entries:
value_entries[payload] = instance
def _cxx_backed_unknown_binder_error(
cls: type,
name: str,
type_key: str,
entries: Any,
) -> RuntimeError:
"""Build a descriptive error for an unbindable bare ``ClassVar`` binder.
A bare ``ClassVar[Cls]`` annotation on a cxx-backed enum means "bind to
an existing C++ entry with this name" — if the C++ registry has no such
entry, the declaration is almost always a typo. For adding a *new*
Python-side variant on a cxx-backed enum, use ``entry(...)`` or
``auto()`` instead.
"""
known = list(entries.keys()) if entries is not None else []
known_str = ", ".join(repr(k) for k in known) if known else "<none>"
return RuntimeError(
f"Cannot bind enum variant {name!r} on {cls.__name__}: the FFI "
f"type {type_key!r} is already registered in C++ with entries "
f"[{known_str}], but has no C++ entry named {name!r}. "
f"Bare ``ClassVar[{cls.__name__}]`` binders on a C++-backed enum "
f"must name an entry already registered in C++; they cannot "
f"introduce new variants from Python. "
f"If this was a typo, double-check the spelling against the known "
f"entries above (`{name}: ClassVar[{cls.__name__}]`); if you meant "
f"to add a new Python-side variant, use `{name} = auto()` or "
f"`{name} = entry(...)` instead."
)
def _instantiate(
cls: type,
*,
args: tuple[Any, ...],
kwargs: dict[str, Any],
ordinal: int,
name: str,
) -> Any:
"""Construct a subclass instance with auto-assigned ``_value``/``_name``."""
merged = dict(kwargs)
merged["_value"] = ordinal
merged["_name"] = name
return cls(*args, **merged)
def _instantiate_entry(
cls: type,
*,
entry: _EnumEntry,
ordinal: int,
name: str,
cxx_backed: bool,
) -> Any:
"""Instantiate an enum entry and normalize construction errors."""
try:
if cxx_backed:
return _instantiate_cxx_backed(
cls, args=entry.args, kwargs=entry.kwargs, ordinal=ordinal, name=name
)
return _instantiate(cls, args=entry.args, kwargs=entry.kwargs, ordinal=ordinal, name=name)
except Exception as err:
raise TypeError(f"{cls.__name__}.{name}: invalid enum entry: {err}") from None
def _instantiate_cxx_backed(
cls: type,
*,
args: tuple[Any, ...],
kwargs: dict[str, Any],
ordinal: int,
name: str,
) -> Any:
"""Construct a new variant of a cxx-backed enum without going through ``__init__``.
C++-backed enums whose underlying type is registered with
``refl::init(false)`` (e.g. any subclass of ``EnumObj`` in C++) have no
``__ffi_init__``, so the usual ``cls(_value=..., _name=...)`` path is not
available. Mirror ``reflection::EnumDef`` by allocating a blank instance
via ``__ffi_new__`` and populating fields through the frozen-setter
escape hatch exposed on the reflected property descriptors.
"""
if args:
raise TypeError(
f"{cls.__name__}.{name}: positional `entry(...)` args are not "
f"supported when extending a C++-backed enum; use keyword "
f"arguments naming reflected fields."
)
type_info = cls.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute]
ffi_new = core._lookup_type_attr(type_info.type_index, "__ffi_new__")
if ffi_new is None:
raise RuntimeError(
f"Cannot add Python enum variant {name!r} on {cls.__name__}: "
f"its C++ type has no ``__ffi_new__`` allocator registered, so "
f"blank instances cannot be created from Python."
)
instance = ffi_new()
for key in ("_value", "_name", *kwargs.keys()):
descriptor = getattr(cls, key, None)
if descriptor is None or not hasattr(descriptor, "set"):
raise TypeError(
f"{cls.__name__}.{name}: cannot set field {key!r} on a "
f"C++-backed enum — no reflected setter is available."
)
getattr(cls, "_value").set(instance, ordinal)
getattr(cls, "_name").set(instance, name)
for k, v in kwargs.items():
getattr(cls, k).set(instance, v)
return instance
# ---------------------------------------------------------------------------
# Annotation helpers
# ---------------------------------------------------------------------------
def _own_annotations(cls: type) -> dict[str, Any]:
"""Return *cls*'s own annotations dict (not inherited)."""
if sys.version_info >= (3, 14):
return dict(getattr(cls, "__annotations__", {}) or {})
return dict(cls.__dict__.get("__annotations__", {}))
def _is_class_var(annotation: Any) -> bool:
"""Return True if *annotation* is ``ClassVar`` or ``ClassVar[...]``."""
if annotation is ClassVar:
return True
if typing.get_origin(annotation) is ClassVar:
return True
if isinstance(annotation, str):
stripped = annotation.replace(" ", "")
return stripped.startswith("ClassVar") or stripped.startswith("typing.ClassVar")
return False
def _annotation_matches_expected(annotation: Any, expected_type: type) -> bool:
"""Return True if *annotation* matches the required payload-field type."""
if annotation is expected_type:
return True
if isinstance(annotation, str):
stripped = annotation.replace(" ", "")
expected = expected_type.__name__
return stripped in {expected, f"builtins.{expected}"}
return False
def _prepare_payload_enum_subclass(
cls: type[Enum], *, value_type: type[Any], base_name: str
) -> None:
"""Inject and validate the user-visible ``value`` field for payload enums."""
if "value" in cls.__dict__:
raise TypeError(
f"{base_name} reserves `value` as a declared field; use "
"`entry(value=...)` or `<variant> = <payload>` to assign each "
"variant's payload."
)
annotations = _own_annotations(cls)
existing = annotations.get("value")
if existing is not None:
if _is_class_var(existing):
raise TypeError(f"{base_name} reserves `value` as a declared field, not a ClassVar.")
if not _annotation_matches_expected(existing, value_type):
raise TypeError(
f"{base_name} requires `value: {value_type.__name__}` on subclasses, "
f"got {existing!r}."
)
return
updated = _own_annotations(cls)
updated["value"] = value_type
cls.__annotations__ = updated
cls.__ffi_enum_payload_value_type__ = value_type