blob: 501c56b0b9cf8fb4eff189ffc8927300bf083e37 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The ``c_class`` decorator: register_object + structural dunders."""
from __future__ import annotations
import typing
from collections.abc import Callable
from typing import Any, TypeVar
from typing_extensions import dataclass_transform
from .field import Field
# Type variable restricted to class objects. ``c_class`` uses it to annotate
# its decorator as ``Callable[[_T], _T]``, i.e. the decorated class comes back
# with the same static type it went in with.
_T = TypeVar("_T", bound=type)
def _attach_field_objects(cls: type, type_info: Any) -> None:
    """Give every entry of ``type_info.fields`` a synthesized :class:`Field`.

    Fields of a ``@c_class`` type come from C++ reflection rather than from
    a user-written :class:`Field`, so one is built here from the reflected
    metadata of each ``TypeField`` and stored on
    ``TypeField.dataclass_field``, where
    :func:`~tvm_ffi.dataclasses.fields` can pick it up.

    Parameters
    ----------
    cls
        The class whose annotations are resolved to fill in ``Field.type``.
    type_info
        Reflection record of ``cls``; only its ``fields`` attribute is read,
        and each element of it is mutated in place.

    Notes
    -----
    Annotation resolution is best-effort: if ``typing.get_type_hints``
    raises, every synthesized field simply gets ``type = None``. The same
    happens for any single field that has no resolved annotation.

    """
    try:
        resolved_hints = typing.get_type_hints(cls)
    except Exception:
        # Unresolvable annotations must not prevent the class from being
        # decorated; carry on without type information.
        resolved_hints = {}
    annotation_for = resolved_hints.get

    for reflected in type_info.fields:
        synthesized = Field(
            name=reflected.name,
            _ty_schema=reflected.ty,
            default=reflected.c_default,
            default_factory=reflected.c_default_factory,
            frozen=reflected.frozen,
            init=reflected.c_init,
            repr=reflected.c_repr,
            hash=reflected.c_hash,
            compare=reflected.c_compare,
            kw_only=reflected.c_kw_only,
            structural_eq=reflected.c_structural_eq,
            doc=reflected.doc,
        )
        # ``None`` when the class carries no (resolvable) annotation for it.
        synthesized.type = annotation_for(reflected.name)
        reflected.dataclass_field = synthesized
@dataclass_transform(eq_default=False, order_default=False)
def c_class(
    type_key: str,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = False,
    order: bool = False,
    unsafe_hash: bool = False,
    match_args: bool = True,
) -> Callable[[_T], _T]:
    """Build a class decorator that binds a Python class to a C++ FFI type.

    The decorator registers the class through
    :func:`~tvm_ffi.register_object` and then installs dataclass-style
    dunder methods (construction, repr, comparison, hashing, ordering)
    driven by the C++ reflection metadata. A dunder that the class body
    defines itself is left untouched.

    Parameters
    ----------
    type_key
        Reflection key naming the C++ type in the FFI registry. It has to
        match a key that the C++ side already registered with
        ``TVM_FFI_DECLARE_OBJECT_INFO``.
    init
        When True (the default), generate ``__init__`` from the C++
        reflection metadata, honoring the ``Init()``, ``KwOnly()`` and
        ``Default()`` traits declared on each C++ field. An ``__init__``
        written in the class body takes precedence.
    repr
        When True (the default), generate ``__repr__`` on top of
        :func:`~tvm_ffi.core.object_repr`, which formats the object with
        the C++ ``ReprPrint`` visitor. A ``__repr__`` written in the class
        body takes precedence.
    eq
        When True, generate ``__eq__`` and ``__ne__`` backed by the C++
        recursive structural comparison (``RecursiveEq``); unrelated types
        yield ``NotImplemented``. Off by default.
    order
        When True, generate ``__lt__``, ``__le__``, ``__gt__`` and
        ``__ge__`` backed by the C++ recursive comparators; unrelated types
        yield ``NotImplemented``. Off by default.
    unsafe_hash
        When True, generate ``__hash__`` backed by ``RecursiveHash``. It is
        *unsafe* in the sense that mutable fields feed the hash: changing
        an object that is currently a set member or dict key breaks the
        container's invariants. Off by default.
    match_args
        When True (the default), set ``__match_args__`` to the tuple of
        positional ``__init__`` field names (those with ``init=True`` that
        are not ``kw_only``) so the class works in ``match`` statements.
        A ``__match_args__`` written in the class body takes precedence.

    Returns
    -------
    Callable[[type], type]
        A class decorator.

    Examples
    --------
    Default settings, which enable ``init`` and ``repr`` only:

    .. code-block:: python

        @c_class("my.Point")
        class Point(Object):
            x: float
            y: float

    Opting in to structural equality, hashing, and ordering:

    .. code-block:: python

        @c_class("my.Point", eq=True, unsafe_hash=True, order=True)
        class Point(Object):
            x: float
            y: float

    See Also
    --------
    :func:`tvm_ffi.register_object`
        Lower-level decorator that only registers the type, without
        installing structural dunders.

    """
    # Resolved when ``c_class`` is called rather than at module import time.
    from .._dunder import _install_dataclass_dunders  # noqa: PLC0415
    from ..registry import (  # noqa: PLC0415
        _warn_missing_field_annotations,
        register_object,
    )

    # The dunder switches are fixed per ``c_class(...)`` call; collect them
    # once and forward them unchanged for the class being decorated.
    dunder_options = {
        "init": init,
        "repr": repr,
        "eq": eq,
        "order": order,
        "unsafe_hash": unsafe_hash,
        "match_args": match_args,
    }

    def decorator(cls: _T) -> _T:
        # Registration is done with init=False; ``__init__`` generation is
        # controlled by the ``init`` flag forwarded to the dunder installer.
        registered = register_object(type_key, init=False)(cls)

        type_info = getattr(registered, "__tvm_ffi_type_info__", None)
        assert type_info is not None

        _warn_missing_field_annotations(registered, type_info, stacklevel=2)
        _attach_field_objects(registered, type_info)
        _install_dataclass_dunders(registered, **dunder_options)

        # Marker separating @c_class / @py_class types from the FFI
        # containers (Array, List, Map, Dict), which also carry
        # __tvm_ffi_type_info__ without being dataclasses. is_dataclass()
        # in common.py checks for it.
        setattr(registered, "__tvm_ffi_is_dataclass__", True)
        return registered

    return decorator