| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """The ``c_class`` decorator: register_object + structural dunders.""" |
| |
| from __future__ import annotations |
| |
| import typing |
| from collections.abc import Callable |
| from typing import Any, TypeVar |
| |
| from typing_extensions import dataclass_transform |
| |
| from .field import Field |
| |
| _T = TypeVar("_T", bound=type) |
| |
| |
| def _attach_field_objects(cls: type, type_info: Any) -> None: |
| """Populate ``TypeField.dataclass_field`` for every own reflected field. |
| |
| ``@c_class`` fields originate from C++ reflection, so there is no |
| user-supplied :class:`Field`. We synthesize one per ``TypeField`` |
| and stash it on ``TypeField.dataclass_field`` so |
| :func:`~tvm_ffi.dataclasses.fields` can return it. |
| """ |
| try: |
| hints = typing.get_type_hints(cls) |
| except Exception: |
| hints = {} |
| for tf in type_info.fields: |
| f = Field( |
| name=tf.name, |
| _ty_schema=tf.ty, |
| default=tf.c_default, |
| default_factory=tf.c_default_factory, |
| frozen=tf.frozen, |
| init=tf.c_init, |
| repr=tf.c_repr, |
| hash=tf.c_hash, |
| compare=tf.c_compare, |
| kw_only=tf.c_kw_only, |
| structural_eq=tf.c_structural_eq, |
| doc=tf.doc, |
| ) |
| f.type = hints.get(tf.name) |
| tf.dataclass_field = f |
| |
| |
| @dataclass_transform(eq_default=False, order_default=False) |
| def c_class( |
| type_key: str, |
| *, |
| init: bool = True, |
| repr: bool = True, |
| eq: bool = False, |
| order: bool = False, |
| unsafe_hash: bool = False, |
| match_args: bool = True, |
| ) -> Callable[[_T], _T]: |
| """Register a C++ FFI class and install structural dunder methods. |
| |
| Combines :func:`~tvm_ffi.register_object` with structural comparison, |
| hashing, and ordering derived from the C++ reflection metadata. |
| User-defined dunders in the class body are never overwritten. |
| |
| Parameters |
| ---------- |
| type_key |
| The reflection key that identifies the C++ type in the FFI registry. |
| Must match a key already registered on the C++ side via |
| ``TVM_FFI_DECLARE_OBJECT_INFO``. |
| init |
| If True (default), install ``__init__`` from C++ reflection metadata. |
| The generated ``__init__`` respects ``Init()``, ``KwOnly()``, and |
| ``Default()`` traits declared on each C++ field. If the class body |
| already defines ``__init__``, it is kept. |
| repr |
| If True (default), install ``__repr__`` using |
| :func:`~tvm_ffi.core.object_repr`, which formats the object via |
| the C++ ``ReprPrint`` visitor. Skipped if the class body already |
| defines ``__repr__``. |
| eq |
| If True, install ``__eq__`` and ``__ne__`` using the C++ recursive |
| structural comparison (``RecursiveEq``). Returns ``NotImplemented`` |
| for unrelated types. Defaults to False. |
| order |
| If True, install ``__lt__``, ``__le__``, ``__gt__``, ``__ge__`` |
| using the C++ recursive comparators. Returns ``NotImplemented`` |
| for unrelated types. Defaults to False. |
| unsafe_hash |
| If True, install ``__hash__`` using ``RecursiveHash``. Called |
| *unsafe* because mutable fields contribute to the hash, so mutating |
| an object while it is in a set or dict key will break invariants. |
| Defaults to False. |
| match_args |
| If True (default), set ``__match_args__`` to a tuple of the |
| positional ``__init__`` field names (``init=True`` and not |
| ``kw_only``), enabling ``match`` statements. Ignored when the |
| class body already defines ``__match_args__``. |
| |
| Returns |
| ------- |
| Callable[[type], type] |
| A class decorator. |
| |
| Examples |
| -------- |
| Basic usage with default settings (``init`` and ``repr`` enabled): |
| |
| .. code-block:: python |
| |
| @c_class("my.Point") |
| class Point(Object): |
| x: float |
| y: float |
| |
| Enable structural equality, hashing, and ordering: |
| |
| .. code-block:: python |
| |
| @c_class("my.Point", eq=True, unsafe_hash=True, order=True) |
| class Point(Object): |
| x: float |
| y: float |
| |
| See Also |
| -------- |
| :func:`tvm_ffi.register_object` |
| Lower-level decorator that only registers the type without |
| installing structural dunders. |
| |
| """ |
| from .._dunder import _install_dataclass_dunders # noqa: PLC0415 |
| from ..registry import ( # noqa: PLC0415 |
| _warn_missing_field_annotations, |
| register_object, |
| ) |
| |
| def decorator(cls: _T) -> _T: |
| cls = register_object(type_key, init=False)(cls) |
| type_info = getattr(cls, "__tvm_ffi_type_info__", None) |
| assert type_info is not None |
| _warn_missing_field_annotations(cls, type_info, stacklevel=2) |
| _attach_field_objects(cls, type_info) |
| _install_dataclass_dunders( |
| cls, |
| init=init, |
| repr=repr, |
| eq=eq, |
| order=order, |
| unsafe_hash=unsafe_hash, |
| match_args=match_args, |
| ) |
| # Marker: distinguishes @c_class / @py_class types from FFI containers |
| # (Array, List, Map, Dict) that also have __tvm_ffi_type_info__ but are |
| # not dataclasses. Used by is_dataclass() in common.py. |
| setattr(cls, "__tvm_ffi_is_dataclass__", True) |
| return cls |
| |
| return decorator |