| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from __future__ import annotations |
| |
| import dataclasses |
| import datetime |
| import enum |
| import inspect |
| import logging |
| import os |
| import sys |
| import typing |
| from typing import List, Dict |
| |
| from pyfory.lib.mmh3 import hash_buffer |
| from pyfory.types import ( |
| TypeId, |
| int8, |
| int16, |
| int32, |
| int64, |
| fixed_int32, |
| fixed_int64, |
| tagged_int64, |
| uint8, |
| uint16, |
| uint32, |
| fixed_uint32, |
| uint64, |
| fixed_uint64, |
| tagged_uint64, |
| float32, |
| float64, |
| is_primitive_array_type, |
| is_list_type, |
| is_map_type, |
| get_primitive_type_size, |
| is_polymorphic_type, |
| is_primitive_type, |
| is_union_type, |
| ) |
| from pyfory.type_util import ( |
| TypeVisitor, |
| infer_field, |
| get_homogeneous_tuple_elem_type, |
| is_subclass, |
| get_type_hints, |
| unwrap_optional, |
| ) |
| from pyfory.serialization import Buffer |
| from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION |
| from pyfory.error import TypeNotCompatibleError |
| from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG |
| from pyfory.field import ( |
| ForyFieldMeta, |
| extract_field_meta, |
| validate_field_metas, |
| ) |
| |
| from pyfory import ( |
| Serializer, |
| BooleanSerializer, |
| ByteSerializer, |
| Int16Serializer, |
| Int32Serializer, |
| Int64Serializer, |
| Float32Serializer, |
| Float64Serializer, |
| StringSerializer, |
| ) |
| |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Integer type hints for which `resolve_missing_field_default` produces 0 when a
# non-nullable field declares neither a default nor a default factory.
_MISSING_DEFAULT_INT_TYPES = {
    int,
    int8,
    int16,
    int32,
    fixed_int32,
    int64,
    fixed_int64,
    tagged_int64,
    uint8,
    uint16,
    uint32,
    fixed_uint32,
    uint64,
    fixed_uint64,
    tagged_uint64,
}

# Floating-point type hints for which `resolve_missing_field_default` produces 0.0
# under the same conditions.
_MISSING_DEFAULT_FLOAT_TYPES = {
    float,
    float32,
    float64,
}
| |
| |
@dataclasses.dataclass
class FieldInfo:
    """Pre-computed field information for serialization.

    `_extract_field_infos` builds one instance per non-ignored dataclass field.
    """

    # Identity
    name: str  # Field name (snake_case)
    index: int  # Position among the non-ignored fields, in declaration order
    type_hint: type  # Type annotation as declared (Optional[...] still wrapped)

    # Fory metadata (from pyfory.field()) - used for hash computation
    tag_id: int  # -1 = use field name, >=0 = use tag ID
    nullable: bool  # Effective nullable flag (considers Optional[T])
    ref: bool  # Field-level ref setting (for hash computation)
    dynamic: bool  # Whether type info is written for this field

    # Runtime flags (combines field metadata with global Fory config)
    runtime_ref_tracking: bool  # Actual ref tracking: field.ref AND fory.track_ref

    # Derived info
    type_id: int  # Fory TypeId; TypeId.UNKNOWN when no serializer was inferred
    serializer: Serializer  # Field serializer (None when inference found none)
    unwrapped_type: type  # Type with Optional unwrapped
| |
| |
def _is_abstract_type(type_hint: type) -> bool:
    """Return True when ``type_hint`` is an abstract class.

    ``None`` is never abstract, and anything ``inspect.isabstract`` rejects
    with ``TypeError`` is reported as not abstract as well.
    """
    if type_hint is not None:
        try:
            return inspect.isabstract(type_hint)
        except TypeError:
            # Not a class (e.g., generic type): fall through to False.
            pass
    return False
| |
| |
def _default_field_meta(type_hint: type, field_nullable: bool = False) -> ForyFieldMeta:
    """Build the metadata used for a field declared without pyfory.field().

    Nullable:
        True when the hint is Optional[T] or when the global ``field_nullable``
        setting is truthy.

    Ref:
        Always False, which preserves the original serialization behavior where
        non-nullable fields go through write_no_ref (no ref header in the
        buffer). Per-field ref tracking must be requested explicitly with
        ``ref=True`` in pyfory.field().

    Dynamic:
        None, meaning "auto-detect later":
        - Abstract classes: always True (type info must be written)
        - Concrete types use type-id based dynamic detection

    Args:
        type_hint: The field's declared annotation.
        field_nullable: Global fallback for the nullable flag.

    Returns:
        A ForyFieldMeta with ``id=-1`` (name-based) and ``ignore=False``.
    """
    # Only the Optional marker matters here; the unwrapped type is not needed.
    _, is_optional = unwrap_optional(type_hint)
    return ForyFieldMeta(
        id=-1,
        nullable=is_optional or field_nullable,
        ref=False,
        ignore=False,
        dynamic=None,
    )
| |
| |
def _extract_field_infos(
    fory,
    clz: type,
    type_hints: dict,
) -> tuple[list[FieldInfo], dict[str, ForyFieldMeta]]:
    """
    Build the FieldInfo list for a dataclass.

    Steps performed:
    - Gather fields across the class hierarchy (base classes first, subclass
      declarations replace inherited fields of the same name)
    - Read pyfory.field() metadata, falling back to defaults when absent
    - Drop fields marked as ignored
    - Validate the collected metadata
    - Derive effective nullable (Optional[T] always counts as nullable),
      runtime ref tracking (field ref AND global track_ref), the field
      serializer, its type id and the effective dynamic flag

    Args:
        fory: Resolver/config object; supplies ``track_ref``, ``get_type_info``,
            ``is_registered_by_id`` and, optionally, ``field_nullable``.
        clz: Class to inspect.
        type_hints: Resolved type hints of ``clz`` keyed by field name.

    Returns:
        Tuple of (field_infos, field_metas) where field_metas maps every field
        name (ignored ones included) to its ForyFieldMeta. Both are empty when
        ``clz`` is not a dataclass.
    """
    if not dataclasses.is_dataclass(clz):
        # Non-dataclass types are handled by the caller's legacy path.
        return [], {}

    # Reverse MRO puts base classes first; later updates overwrite the value of
    # an existing key while keeping its original position.
    declared: Dict[str, dataclasses.Field] = {}
    for base in reversed(clz.__mro__):
        if base is clz or not dataclasses.is_dataclass(base):
            continue
        declared.update((f.name, f) for f in dataclasses.fields(base))
    declared.update((f.name, f) for f in dataclasses.fields(clz))

    # The global setting only matters for fields lacking explicit metadata.
    global_field_nullable = getattr(fory, "field_nullable", False)

    field_metas: Dict[str, ForyFieldMeta] = {}
    active_names: List[str] = []
    for field_name, dc_field in declared.items():
        meta = extract_field_meta(dc_field)
        if meta is None:
            # No pyfory.field(): derive defaults from the annotation.
            meta = _default_field_meta(type_hints.get(field_name, typing.Any), global_field_nullable)
        field_metas[field_name] = meta
        if meta.ignore:
            continue
        active_names.append(field_name)

    validate_field_metas(clz, field_metas, type_hints)

    visitor = StructFieldSerializerVisitor(fory)
    global_ref_tracking = fory.track_ref
    field_infos: List[FieldInfo] = []

    for index, field_name in enumerate(active_names):
        meta = field_metas[field_name]
        type_hint = type_hints.get(field_name, typing.Any)
        unwrapped_type, is_optional = unwrap_optional(type_hint)

        serializer = infer_field(field_name, unwrapped_type, visitor, types_path=[])
        if serializer is None:
            type_id = TypeId.UNKNOWN
        else:
            type_id = fory.get_type_info(serializer.type_).type_id

        # Dynamic resolution, in priority order:
        #   1. abstract classes always carry type info
        #   2. an explicit (non-None) setting wins
        #   3. otherwise polymorphic types not registered by id carry type info
        if _is_abstract_type(unwrapped_type):
            dynamic = True
        elif meta.dynamic is not None:
            dynamic = meta.dynamic
        else:
            dynamic = is_polymorphic_type(type_id) and not fory.is_registered_by_id(unwrapped_type)

        field_infos.append(
            FieldInfo(
                name=field_name,
                index=index,
                type_hint=type_hint,
                tag_id=meta.id,
                # Optional[T] is nullable regardless of the explicit meta.
                nullable=meta.nullable or is_optional,
                ref=meta.ref,
                dynamic=dynamic,
                runtime_ref_tracking=meta.ref and global_ref_tracking,
                type_id=type_id,
                serializer=serializer,
                unwrapped_type=unwrapped_type,
            )
        )

    return field_infos, field_metas
| |
| |
def resolve_missing_field_default(
    dc_field: dataclasses.Field,
    type_resolver,
    type_hints: dict[str, typing.Any],
) -> typing.Callable[[], typing.Any]:
    """Return a zero-argument factory producing the value for an absent field.

    Resolution order:
    1. The dataclass ``default`` (a ``None`` default on a non-nullable Enum
       field is replaced by the first enum member, when the enum has members).
    2. The dataclass ``default_factory`` itself.
    3. For non-nullable fields, a type-based zero value: first enum member,
       empty list/set/dict, ``False``, ``0``, ``0.0``, ``""`` or ``b""``.
    4. Otherwise ``None``.

    Args:
        dc_field: The dataclass field to resolve a default for.
        type_resolver: Supplies ``field_nullable`` when the field carries no
            pyfory.field() metadata.
        type_hints: Resolved type hints keyed by field name.

    Returns:
        A callable taking no arguments.
    """
    hint = type_hints.get(dc_field.name, typing.Any)
    unwrapped_type, is_optional = unwrap_optional(hint)

    meta = extract_field_meta(dc_field)
    if meta is None:
        declared_nullable = type_resolver.field_nullable
    else:
        declared_nullable = meta.nullable
    effective_nullable = declared_nullable or is_optional

    explicit_default = dc_field.default
    if explicit_default is not dataclasses.MISSING:
        if explicit_default is None and not effective_nullable and is_subclass(unwrapped_type, enum.Enum):
            members = tuple(unwrapped_type)
            if members:
                explicit_default = members[0]
        # Bind through a default argument so the lambda captures the value now.
        return lambda value=explicit_default: value

    factory = dc_field.default_factory
    if factory is not dataclasses.MISSING:
        return factory

    if effective_nullable:
        return lambda: None

    if hasattr(typing, "get_origin"):
        origin = typing.get_origin(unwrapped_type)
    else:
        origin = getattr(unwrapped_type, "__origin__", None)
    # Non-generic hints have no origin; compare the type itself in that case.
    origin = origin or unwrapped_type

    if is_subclass(unwrapped_type, enum.Enum):
        members = tuple(unwrapped_type)
        if members:
            first_member = members[0]
            return lambda value=first_member: value
    if origin is list or origin == typing.List:
        return lambda: []
    if origin is set or origin == typing.Set:
        return lambda: set()
    if origin is dict or origin == typing.Dict:
        return lambda: {}
    if unwrapped_type is bool:
        return lambda: False
    if unwrapped_type in _MISSING_DEFAULT_INT_TYPES:
        return lambda: 0
    if unwrapped_type in _MISSING_DEFAULT_FLOAT_TYPES:
        return lambda: 0.0
    if unwrapped_type is str:
        return lambda: ""
    if unwrapped_type is bytes:
        return lambda: b""
    return lambda: None
| |
| |
def _resolve_missing_field_default(dc_field, type_resolver, type_hints):
    """Private alias that forwards unchanged to `resolve_missing_field_default`."""
    default_factory = resolve_missing_field_default(dc_field, type_resolver, type_hints)
    return default_factory
| |
| |
def build_default_values_factory(type_resolver, type_hints, dc_fields=()):
    """Map each field name in ``dc_fields`` to its default-value factory.

    Args:
        type_resolver: Passed through to the per-field default resolution.
        type_hints: Resolved type hints keyed by field name.
        dc_fields: Iterable of dataclass fields; empty by default.

    Returns:
        Dict of field name -> zero-argument callable, in iteration order of
        ``dc_fields``.
    """
    factories = {}
    for dc_field in dc_fields:
        factories[dc_field.name] = _resolve_missing_field_default(dc_field, type_resolver, type_hints)
    return factories
| |
| |
class DataClassSerializer(Serializer):
    """Field-by-field serializer for dataclass-like objects.

    The field layout comes from one of three sources:
    - explicit ``field_names`` and ``serializers`` (both supplied), or
    - ``_extract_field_infos`` when it yields field infos for ``clz``, or
    - a legacy path based on ``_get_field_names`` and the class type hints.

    When the type resolver is not in compatible mode, a 32-bit struct hash is
    written before the fields and verified on read.
    """

    # Fields whose serializer is one of these (and that are not dynamic) are
    # written/read through the serializer directly instead of the context's
    # ref / no-ref helpers.
    _BASIC_SERIALIZERS = (
        BooleanSerializer,
        ByteSerializer,
        Int16Serializer,
        Int32Serializer,
        Int64Serializer,
        Float32Serializer,
        Float64Serializer,
        StringSerializer,
    )

    def __init__(
        self,
        type_resolver,
        clz: type,
        field_names: List[str] = None,
        serializers: List[Serializer] = None,
        nullable_fields: Dict[str, bool] = None,
        dynamic_fields: Dict[str, bool] = None,
        ref_fields: Dict[str, bool] = None,
    ):
        """Resolve field layout, struct hash and per-field flags for ``clz``.

        Args:
            type_resolver: Resolver providing type info and mode flags.
            clz: The class this serializer handles.
            field_names: Explicit field order; used as-is only together with
                ``serializers``.
            serializers: Explicit per-field serializers matching ``field_names``.
            nullable_fields: Optional field name -> nullable flag.
            dynamic_fields: Optional field name -> dynamic flag (only used when
                both ``field_names`` and ``serializers`` are given).
            ref_fields: Optional field name -> ref-tracking flag (only used when
                both ``field_names`` and ``serializers`` are given).
        """
        super().__init__(type_resolver, clz)

        self._type_hints = get_type_hints(clz)
        self._has_slots = hasattr(clz, "__slots__")

        typedef_driven = field_names is not None and serializers is not None
        self._fields_from_typedef = typedef_driven
        if typedef_driven:
            # Layout is dictated by the caller; copy the lists defensively.
            self._field_names = list(field_names)
            self._serializers = list(serializers)
            self._nullable_fields = nullable_fields or {}
            self._ref_fields = ref_fields or {}
            self._dynamic_fields = dynamic_fields or {}
            self._field_infos = []
            self._field_metas = {}
        else:
            infos, metas = _extract_field_infos(type_resolver, clz, self._type_hints)
            self._field_infos, self._field_metas = infos, metas
            if infos:
                self._field_names = [info.name for info in infos]
                self._serializers = [info.serializer for info in infos]
                self._nullable_fields = {info.name: info.nullable for info in infos}
                self._ref_fields = {info.name: info.runtime_ref_tracking for info in infos}
                self._dynamic_fields = {info.name: info.dynamic for info in infos}
            else:
                # Legacy path: no field infos were produced for this class.
                self._field_names = field_names or self._get_field_names(clz)
                self._nullable_fields = nullable_fields or {}
                self._ref_fields = {}
                self._dynamic_fields = {}
                if self._field_names and not self._nullable_fields:
                    for name in self._field_names:
                        if name not in self._type_hints:
                            continue
                        bare_type, optional = unwrap_optional(self._type_hints[name])
                        self._nullable_fields[name] = optional or not is_primitive_type(bare_type)
                self._serializers = serializers or [None] * len(self._field_names)
                if serializers is None:
                    visitor = StructFieldSerializerVisitor(type_resolver)
                    for position, name in enumerate(self._field_names):
                        bare_type, _ = unwrap_optional(self._type_hints.get(name, typing.Any))
                        self._serializers[position] = infer_field(name, bare_type, visitor, types_path=[])

        self._unwrapped_hints = self._compute_unwrapped_hints()

        if typedef_driven:
            # Keep the caller's field order; only the hash is derived here.
            fingerprint = compute_struct_fingerprint(
                self.type_resolver, self._field_names, self._serializers, self._nullable_fields, self._field_infos
            ).encode("utf-8")
            if not fingerprint:
                self._hash = 47
            else:
                full_hash = hash_buffer(fingerprint, seed=47)[0]
                low32 = full_hash & 0xFFFFFFFF
                if full_hash & 0x80000000:
                    # Reinterpret the low 32 bits as a signed int32.
                    low32 -= 0x100000000
                self._hash = low32
        else:
            # compute_struct_meta also returns the re-ordered names/serializers.
            self._hash, self._field_names, self._serializers = compute_struct_meta(
                self.type_resolver, self._field_names, self._serializers, self._nullable_fields, self._field_infos
            )

        self._field_name_interned = {name: sys.intern(name) for name in self._field_names}
        self._current_class_field_names = set(self._get_field_names(self.type_))
        if dataclasses.is_dataclass(self.type_):
            self._default_values_factory = build_default_values_factory(
                self.type_resolver, self._type_hints, dataclasses.fields(self.type_)
            )
        else:
            self._default_values_factory = {}
        self._missing_field_defaults = self._build_missing_field_defaults()
        self._basic_field_flags = [
            (not self._dynamic_fields.get(name, False)) and isinstance(self._serializers[position], self._BASIC_SERIALIZERS)
            for position, name in enumerate(self._field_names)
        ]

    def _get_field_names(self, clz):
        """Return the field names of ``clz``.

        Dataclasses yield their declared field order; other objects exposing
        ``__dict__`` yield the sorted type-hint names; otherwise ``__slots__``
        (sorted, or a single-element list for a string) is used.
        """
        # NOTE(review): class objects normally expose ``__dict__`` themselves, so
        # the ``__slots__`` branch below looks reachable only for unusual inputs
        # - confirm this is intended.
        if hasattr(clz, "__dict__"):
            if dataclasses.is_dataclass(clz):
                return [dc_field.name for dc_field in dataclasses.fields(clz)]
            return sorted(self._type_hints.keys())
        if hasattr(clz, "__slots__"):
            slots = clz.__slots__
            return [slots] if isinstance(slots, str) else sorted(slots)
        return []

    def _compute_unwrapped_hints(self):
        """Return field name -> type hint with any Optional wrapper removed."""
        unwrapped = {}
        for name, hint in self._type_hints.items():
            unwrapped[name] = unwrap_optional(hint)[0]
        return unwrapped

    def _build_missing_field_defaults(self):
        """List (name, factory) pairs for class fields absent from the layout.

        Only populated in compatible mode and when default factories exist;
        otherwise an empty list is returned.
        """
        if not self.type_resolver.compatible or not self._default_values_factory:
            return []
        absent = self._current_class_field_names - set(self._field_names)
        if not absent:
            return []
        pairs = []
        for name, factory in self._default_values_factory.items():
            if name in absent:
                pairs.append((name, factory))
        return pairs

    def _write_field_value(self, write_context, serializer, field_value, is_nullable, is_dynamic, is_basic, is_tracking_ref):
        """Write one field value using the strategy selected by its flags.

        - basic: optional null flag, then the serializer writes the value
        - ref-tracked: delegated to ``write_ref`` (no serializer when dynamic)
        - otherwise: optional null flag, then ``write_no_ref`` (no serializer
          when dynamic)
        """
        if is_basic:
            if is_nullable:
                if field_value is None:
                    write_context.write_int8(NULL_FLAG)
                    return
                write_context.write_int8(NOT_NULL_VALUE_FLAG)
            serializer.write(write_context, field_value)
            return

        if is_tracking_ref:
            write_context.write_ref(field_value, serializer=None if is_dynamic else serializer)
            return

        if is_nullable:
            if field_value is None:
                write_context.write_int8(NULL_FLAG)
                return
            write_context.write_int8(NOT_NULL_VALUE_FLAG)

        if is_dynamic:
            write_context.write_no_ref(field_value)
        else:
            write_context.write_no_ref(field_value, serializer=serializer)

    def _read_field_value(self, read_context, serializer, is_nullable, is_dynamic, is_basic, is_tracking_ref):
        """Read one field value; mirror image of ``_write_field_value``."""
        if is_basic:
            if is_nullable and read_context.read_int8() == NULL_FLAG:
                return None
            return serializer.read(read_context)

        if is_tracking_ref:
            return read_context.read_ref(serializer=None if is_dynamic else serializer)

        if is_nullable and read_context.read_int8() == NULL_FLAG:
            return None

        if is_dynamic:
            return read_context.read_no_ref()
        return read_context.read_no_ref(serializer=serializer)

    def write(self, write_context: Buffer, value):
        """Serialize ``value`` field by field into ``write_context``.

        Outside compatible mode the struct hash is written first and a missing
        attribute raises (KeyError / AttributeError); in compatible mode a
        missing attribute is written as ``None``.
        """
        compatible = self.type_resolver.compatible
        if not compatible:
            write_context.write_int32(self._hash)

        # Pick the attribute accessor once instead of branching per field.
        attrs = None if self._has_slots else value.__dict__
        if attrs is not None:
            fetch = attrs.get if compatible else attrs.__getitem__
        elif compatible:

            def fetch(name):
                return getattr(value, name, None)

        else:

            def fetch(name):
                return getattr(value, name)

        for position, name in enumerate(self._field_names):
            field_value = fetch(self._field_name_interned[name])
            self._write_field_value(
                write_context,
                self._serializers[position],
                field_value,
                self._nullable_fields.get(name, False),
                self._dynamic_fields.get(name, False),
                self._basic_field_flags[position],
                self._ref_fields.get(name, False),
            )
        write_context.try_flush()

    def read(self, read_context):
        """Deserialize an instance of ``self.type_`` from ``read_context``.

        Raises:
            TypeNotCompatibleError: outside compatible mode, when the hash read
                from the stream differs from this serializer's struct hash.
        """
        if not self.type_resolver.strict:
            read_context.policy.authorize_instantiation(self.type_)
        if not self.type_resolver.compatible:
            hash_ = read_context.read_int32()
            if hash_ != self._hash:
                raise TypeNotCompatibleError(
                    f"Hash {hash_} is not consistent with {self._hash} for type {self.type_}",
                )

        # Bypass __init__; the object is registered before its fields are read.
        obj = self.type_.__new__(self.type_)
        read_context.reference(obj)
        attrs = None if self._has_slots else obj.__dict__

        for position, name in enumerate(self._field_names):
            field_value = self._read_field_value(
                read_context,
                self._serializers[position],
                self._nullable_fields.get(name, False),
                self._dynamic_fields.get(name, False),
                self._basic_field_flags[position],
                self._ref_fields.get(name, False),
            )
            if name not in self._current_class_field_names:
                # The value was consumed from the stream but the class has no such field.
                continue
            key = self._field_name_interned[name]
            if attrs is None:
                setattr(obj, key, field_value)
            else:
                attrs[key] = field_value

        if self._missing_field_defaults:
            for name, factory in self._missing_field_defaults:
                default_value = factory()
                if attrs is None:
                    setattr(obj, name, default_value)
                else:
                    attrs[name] = default_value
        read_context.shrink_input_buffer()
        return obj
| |
| |
class DataClassStubSerializer(DataClassSerializer):
    """Placeholder that builds the real DataClassSerializer on first use.

    Every ``write``/``read`` call constructs a DataClassSerializer for the
    type, stores it on the resolver's type info, and delegates to it.
    """

    def __init__(self, type_resolver, clz: type):
        # Only the base Serializer initializer runs; DataClassSerializer's own
        # field resolution is skipped here.
        Serializer.__init__(self, type_resolver, clz)

    def write(self, write_context, value):
        """Install the real serializer and delegate the write to it."""
        real_serializer = self._replace()
        real_serializer.write(write_context, value)

    def read(self, read_context):
        """Install the real serializer and delegate the read to it."""
        real_serializer = self._replace()
        return real_serializer.read(read_context)

    def _replace(self):
        """Create a DataClassSerializer and store it on the type's type info."""
        typeinfo = self.type_resolver.get_type_info(self.type_)
        real_serializer = DataClassSerializer(self.type_resolver, self.type_)
        typeinfo.serializer = real_serializer
        return real_serializer
| |
| |
# Types for which StructFieldSerializerVisitor.visit_other asks the type resolver
# for a serializer. Any other type that is neither an Enum subclass nor a
# primitive array type gets no field serializer (None) from visit_other.
basic_types = {
    bool,
    # Signed integers
    int8,
    int16,
    int32,
    fixed_int32,
    int64,
    fixed_int64,
    tagged_int64,
    # Unsigned integers
    uint8,
    uint16,
    uint32,
    fixed_uint32,
    uint64,
    fixed_uint64,
    tagged_uint64,
    # Floats
    float32,
    float64,
    # Python native types
    int,
    float,
    str,
    bytes,
    datetime.datetime,
    datetime.date,
    datetime.time,
}
| |
| |
class StructFieldSerializerVisitor(TypeVisitor):
    """Type visitor that maps a field's type hint to a field serializer.

    Container hints are resolved recursively through ``infer_field``; each
    ``visit_*`` method returns a serializer or ``None`` when none applies.
    """

    def __init__(
        self,
        type_resolver,
    ):
        self.type_resolver = type_resolver

    def visit_list(self, field_name, elem_type, types_path=None):
        """Build a ListSerializer, resolving the element type recursively."""
        from pyfory.serializer import ListSerializer  # Local import
        from pyfory.type_util import unwrap_ref

        # Recursion handles nested hints such as List[Dict[str, str]].
        bare_elem_type, ref_override = unwrap_ref(elem_type)
        item_serializer = infer_field("item", bare_elem_type, self, types_path=types_path)
        return ListSerializer(self.type_resolver, list, item_serializer, ref_override)

    def visit_set(self, field_name, elem_type, types_path=None):
        """Build a SetSerializer, resolving the element type recursively."""
        from pyfory.serializer import SetSerializer  # Local import
        from pyfory.type_util import unwrap_ref

        # Recursion handles nested hints such as Set[Dict[str, str]].
        bare_elem_type, ref_override = unwrap_ref(elem_type)
        item_serializer = infer_field("item", bare_elem_type, self, types_path=types_path)
        return SetSerializer(self.type_resolver, set, item_serializer, ref_override)

    def visit_tuple(self, field_name, elem_types, types_path=None):
        """Build a TupleSerializer; typed only for homogeneous tuples."""
        from pyfory.serializer import TupleSerializer  # Local import
        from pyfory.type_util import unwrap_ref

        shared_elem_type = get_homogeneous_tuple_elem_type(elem_types)
        if shared_elem_type is None:
            # No single element type: fall back to an untyped tuple serializer.
            return TupleSerializer(self.type_resolver, tuple)
        bare_elem_type, ref_override = unwrap_ref(shared_elem_type)
        item_serializer = infer_field("item", bare_elem_type, self, types_path=types_path)
        return TupleSerializer(self.type_resolver, tuple, item_serializer, ref_override)

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        """Build a MapSerializer, resolving key and value types recursively."""
        from pyfory.serializer import MapSerializer  # Local import
        from pyfory.type_util import unwrap_ref

        # Recursion handles nested hints such as Dict[str, Dict[str, str]].
        bare_key_type, key_ref_override = unwrap_ref(key_type)
        bare_value_type, value_ref_override = unwrap_ref(value_type)
        key_serializer = infer_field("key", bare_key_type, self, types_path=types_path)
        value_serializer = infer_field("value", bare_value_type, self, types_path=types_path)
        return MapSerializer(
            self.type_resolver,
            dict,
            key_serializer,
            value_serializer,
            key_ref_override,
            value_ref_override,
        )

    def visit_customized(self, field_name, type_, types_path=None):
        """Return the serializer for an Enum or an already-known custom type."""
        if issubclass(type_, enum.Enum):
            return self.type_resolver.get_serializer(type_)
        # Custom types (dataclasses, etc.): reuse existing type info without
        # creating it, which enables field-level resolution for inner structs.
        typeinfo = self.type_resolver.get_type_info(type_, create=False)
        return None if typeinfo is None else typeinfo.serializer

    def visit_other(self, field_name, type_, types_path=None):
        """Return a serializer for enums, basic types and primitive arrays only."""
        if is_subclass(type_, enum.Enum) or type_ in basic_types or is_primitive_array_type(type_):
            return self.type_resolver.get_serializer(type_)
        return None
| |
| |
# Placeholder type id that group_fields records for fields whose serializer is None.
_UNKNOWN_TYPE_ID = -1
| |
| |
def _sort_fields(type_resolver, field_names, serializers, nullable_map=None, field_infos_list=None):
    """Return (field_names, serializers) in serialization order.

    Fields are grouped and sorted by ``group_fields``; the groups are then
    concatenated in the order: boxed, nullable boxed, internal, collection,
    set, map, other.
    """
    (boxed, nullable_boxed, internal, collections_, sets, maps, others) = group_fields(
        type_resolver, field_names, serializers, nullable_map, field_infos_list
    )
    ordered = []
    for group in (boxed, nullable_boxed, internal, collections_, sets, maps, others):
        ordered = ordered + group
    # Each entry is (type_id, serializer, field_name, sort_key).
    sorted_names = [entry[2] for entry in ordered]
    sorted_serializers = [entry[1] for entry in ordered]
    return sorted_names, sorted_serializers
| |
| |
def group_fields(type_resolver, field_names, serializers, nullable_map=None, field_infos_list=None):
    """Partition fields into type groups and sort each group.

    Every entry is a tuple ``(type_id, serializer, field_name, sort_key)``
    where ``sort_key`` is ``(0, str(tag_id), "")`` for fields with a tag id
    >= 0 and ``(1, field_name, "")`` otherwise.

    Args:
        type_resolver: Supplies ``get_type_info(...).type_id`` per serializer.
        field_names: Field names, paired positionally with ``serializers``.
        serializers: Field serializers; ``None`` entries land in the "other"
            group with the unknown type id.
        nullable_map: Optional field name -> nullable flag.
        field_infos_list: Optional FieldInfo list used to look up tag ids.

    Returns:
        Tuple of seven lists: (boxed, nullable boxed, internal, collection,
        set, map, other).
    """
    nullable_map = nullable_map or {}
    infos_by_name = {fi.name: fi for fi in field_infos_list} if field_infos_list else {}

    boxed, nullable_boxed, internal, collections_, sets, maps, others = ([] for _ in range(7))

    # Pass 1: resolve type ids; fields without a serializer go straight to "other".
    resolved = []
    for field_name, serializer in zip(field_names, serializers):
        fi = infos_by_name.get(field_name)
        tag_id = fi.tag_id if fi else -1
        sort_key = (0, str(tag_id), "") if tag_id >= 0 else (1, field_name, "")
        if serializer is None:
            others.append((_UNKNOWN_TYPE_ID, serializer, field_name, sort_key))
            continue
        resolved_id = type_resolver.get_type_info(serializer.type_).type_id
        resolved.append((resolved_id, serializer, field_name, sort_key))

    # Pass 2: pick the group for each resolved field.
    for type_id, serializer, field_name, sort_key in resolved:
        if is_union_type(type_id):
            type_id = TypeId.UNION
        if is_primitive_type(type_id):
            bucket = nullable_boxed if nullable_map.get(field_name, False) else boxed
        elif type_id == TypeId.SET:
            bucket = sets
        elif type_id == TypeId.LIST or is_list_type(serializer.type_):
            bucket = collections_
        elif is_map_type(serializer.type_):
            bucket = maps
        elif is_polymorphic_type(type_id) or type_id in {TypeId.ENUM, TypeId.NAMED_ENUM} or is_union_type(type_id):
            bucket = others
        elif type_id >= TypeId.BOUND:
            # Native mode user-registered types have type_id >= BOUND
            bucket = others
        else:
            assert TypeId.UNKNOWN < type_id < TypeId.BOUND, (type_id,)
            bucket = internal
        bucket.append((type_id, serializer, field_name, sort_key))

    def by_type_then_key(entry):
        return entry[0], entry[3]

    def by_key(entry):
        return entry[3]

    def by_numeric_layout(entry):
        type_id = entry[0]
        is_compressed = type_id in {
            # Signed compressed types
            TypeId.VARINT32,
            TypeId.VARINT64,
            TypeId.TAGGED_INT64,
            # Unsigned compressed types
            TypeId.VAR_UINT32,
            TypeId.VAR_UINT64,
            TypeId.TAGGED_UINT64,
        }
        # Order: uncompressed first, then larger size first, then higher type
        # id first, then the field sort key (matches the Java ordering).
        return int(is_compressed), -get_primitive_type_size(type_id), -type_id, entry[3]

    boxed.sort(key=by_numeric_layout)
    nullable_boxed.sort(key=by_numeric_layout)
    collections_.sort(key=by_type_then_key)
    sets.sort(key=by_type_then_key)
    internal.sort(key=by_type_then_key)
    maps.sort(key=by_type_then_key)
    others.sort(key=by_key)
    return (boxed, nullable_boxed, internal, collections_, sets, maps, others)
| |
| |
def compute_struct_fingerprint(type_resolver, field_names, serializers, nullable_map=None, field_infos_list=None):
    """
    Build the fingerprint string of a struct type, used for schema versioning.

    Fingerprint Format:
        One segment per field: <field_id_or_name>,<type_id>,<ref>,<nullable>;
        Segments are ordered with tag-ID fields first (by tag ID string), then
        name-based fields (by field name).

    Field Components:
        - field_id_or_name: Tag ID as string if id >= 0, otherwise field name
        - type_id: Fory TypeId; UNKNOWN for fields without a serializer and for
          union, polymorphic and enum types
        - ref: "1" if the field has ref=True in pyfory.field(), "0" otherwise
               (based on field annotation, NOT runtime config)
        - nullable: "1" if the field is nullable per ``nullable_map``, "0" otherwise

    Example fingerprints:
        With tag IDs: "0,4,0,0;1,12,0,1;2,0,0,1;"
        With field names: "age,4,0,0;email,12,0,1;name,9,0,0;"

    This format is consistent across Go, Java, Rust, C++, and Python implementations.

    Args:
        type_resolver: Supplies ``get_type_info(...).type_id`` per serializer.
        field_names: Field names, indexed in parallel with ``serializers``.
        serializers: Field serializers (entries may be ``None``).
        nullable_map: Optional field name -> nullable flag (missing = False).
        field_infos_list: Optional FieldInfo list providing tag ids and ref flags.

    Returns:
        The fingerprint string; empty when there are no fields.
    """
    nullable_map = {} if nullable_map is None else nullable_map
    infos_by_name = {fi.name: fi for fi in field_infos_list} if field_infos_list else {}

    # Each entry: (sort_key, field_id_or_name, type_id, ref_flag, nullable_flag)
    entries = []
    for position, field_name in enumerate(field_names):
        serializer = serializers[position]

        fi = infos_by_name.get(field_name)
        tag_id = fi.tag_id if fi else -1
        ref_flag = "1" if (fi and fi.ref) else "0"

        if serializer is None:
            type_id = TypeId.UNKNOWN
        else:
            type_id = type_resolver.get_type_info(serializer.type_).type_id
            if is_union_type(type_id):
                # customized types can't be detected at compile time for some languages
                type_id = TypeId.UNKNOWN
            if is_polymorphic_type(type_id) or type_id in {TypeId.ENUM, TypeId.NAMED_ENUM}:
                type_id = TypeId.UNKNOWN

        # Nullability always comes from the map, whatever the serializer is.
        nullable_flag = "1" if nullable_map.get(field_name, False) else "0"

        if tag_id >= 0:
            identifier = str(tag_id)
            sort_key = (0, identifier, "")  # 0 = tag ID fields come first (lexicographic by ID string)
        else:
            identifier = field_name
            sort_key = (1, field_name, "")  # 1 = name fields come after (lexicographic by name)

        entries.append((sort_key, identifier, type_id, ref_flag, nullable_flag))

    entries.sort(key=lambda entry: entry[0])

    return "".join(
        f"{identifier},{type_id},{ref_flag},{nullable_flag};" for _, identifier, type_id, ref_flag, nullable_flag in entries
    )
| |
| |
def compute_struct_meta(type_resolver, field_names, serializers, nullable_map=None, field_infos_list=None):
    """
    Build the struct metadata triple: version hash, ordered field names, ordered serializers.

    The fingerprint string produced by compute_struct_fingerprint is UTF-8 encoded and
    hashed with MurmurHash3 (seed 47). The low 32 bits of that hash, reinterpreted as a
    signed int32, form the struct version hash. An empty fingerprint is not hashed; the
    seed value 47 is used in its place.

    This is the cross-language struct version ID used by class version checking,
    consistent with the Go, Java, Rust, and C++ implementations.

    Returns:
        A tuple ``(type_hash_32, sorted_field_names, sorted_serializers)`` where the two
        lists follow the concatenated group order returned by group_fields.
    """
    # Grouping runs first, exactly as before; unpacking also checks that seven groups come back.
    (boxed, nullable_boxed, internal, collections_, sets_, maps_, others) = group_fields(
        type_resolver, field_names, serializers, nullable_map, field_infos_list
    )

    hash_str = compute_struct_fingerprint(type_resolver, field_names, serializers, nullable_map, field_infos_list)
    encoded_fingerprint = hash_str.encode("utf-8")

    # No fields (empty fingerprint): fall back to the seed instead of hashing nothing.
    full_hash = hash_buffer(encoded_fingerprint, seed=47)[0] if encoded_fingerprint else 47

    # Keep the low 32 bits and reinterpret them as a two's-complement signed int32.
    low_bits = full_hash & 0xFFFFFFFF
    type_hash_32 = low_bits - 0x100000000 if low_bits & 0x80000000 else low_bits
    assert type_hash_32 != 0

    debug_setting = os.environ.get("ENABLE_FORY_DEBUG_OUTPUT", "").lower()
    if debug_setting in ("1", "true"):
        print(f'[Python][fory-debug] struct version fingerprint="{hash_str}" version hash={type_hash_32}')

    # Concatenate the groups in their fixed order; each entry carries the
    # serializer at index 1 and the field name at index 2.
    ordered_entries = [*boxed, *nullable_boxed, *internal, *collections_, *sets_, *maps_, *others]
    sorted_field_names = []
    sorted_serializers = []
    for entry in ordered_entries:
        sorted_field_names.append(entry[2])
        sorted_serializers.append(entry[1])

    return type_hash_32, sorted_field_names, sorted_serializers
| |
| |
class StructTypeIdVisitor(TypeVisitor):
    """
    TypeVisitor that translates declared field types into type ids.

    Container visits return a tuple that starts with the container's TypeId and is
    followed by the recursively inferred results for the nested types. Leaf visits
    return a one-element list holding the resolved type id.
    """

    def __init__(
        self,
        type_resolver,
        cls,
    ):
        self.cls = cls
        self.type_resolver = type_resolver

    def visit_list(self, field_name, elem_type, types_path=None):
        """Return ``(TypeId.LIST, <inferred element ids>)``; nesting such as List[Dict[str, str]] recurses."""
        return TypeId.LIST, infer_field("item", elem_type, self, types_path=types_path)

    def visit_set(self, field_name, elem_type, types_path=None):
        """Return ``(TypeId.SET, <inferred element ids>)``; nesting such as Set[Dict[str, str]] recurses."""
        return TypeId.SET, infer_field("item", elem_type, self, types_path=types_path)

    def visit_tuple(self, field_name, elem_types, types_path=None):
        """Describe a tuple as a LIST; without a homogeneous element type the element id is UNKNOWN."""
        homogeneous = get_homogeneous_tuple_elem_type(elem_types)
        if homogeneous is not None:
            return TypeId.LIST, infer_field("item", homogeneous, self, types_path=types_path)
        return TypeId.LIST, [TypeId.UNKNOWN]

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        """Return ``(TypeId.MAP, <key ids>, <value ids>)``; the key is inferred before the value."""
        return (
            TypeId.MAP,
            infer_field("key", key_type, self, types_path=types_path),
            infer_field("value", value_type, self, types_path=types_path),
        )

    def visit_customized(self, field_name, type_, types_path=None):
        """Return the looked-up type id, or UNKNOWN when the resolver has no type info (lookup uses create=False)."""
        info = self.type_resolver.get_type_info(type_, create=False)
        return [TypeId.UNKNOWN if info is None else info.type_id]

    def visit_other(self, field_name, type_, types_path=None):
        """
        Resolve the type id of enums, members of basic_types, and primitive array types.

        Any other type yields ``(None, None)``.
        """
        resolvable = is_subclass(type_, enum.Enum) or type_ in basic_types or is_primitive_array_type(type_)
        if not resolvable:
            return None, None
        return [self.type_resolver.get_type_info(type_).type_id]
| |
| |
class StructTypeVisitor(TypeVisitor):
    """
    TypeVisitor that reduces declared field types to plain type descriptors.

    Container visits return a tuple that starts with the container marker
    (typing.List, typing.Set, tuple or typing.Dict) and is followed by the
    recursively inferred nested types. Leaf visits return ``[type_]``.
    """

    def __init__(self, cls):
        self.cls = cls

    def visit_list(self, field_name, elem_type, types_path=None):
        """Return ``(typing.List, <inferred element types>)``; nesting such as List[Dict[str, str]] recurses."""
        return typing.List, infer_field("item", elem_type, self, types_path=types_path)

    def visit_set(self, field_name, elem_type, types_path=None):
        """Return ``(typing.Set, <inferred element types>)``; nesting such as Set[Dict[str, str]] recurses."""
        return typing.Set, infer_field("item", elem_type, self, types_path=types_path)

    def visit_tuple(self, field_name, elem_types, types_path=None):
        """Return ``(tuple, <inferred element types>)``, or ``(tuple, None)`` without a homogeneous element type."""
        homogeneous = get_homogeneous_tuple_elem_type(elem_types)
        nested = None if homogeneous is None else infer_field("item", homogeneous, self, types_path=types_path)
        return tuple, nested

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        """Return ``(typing.Dict, <key types>, <value types>)``; the key is inferred before the value."""
        return (
            typing.Dict,
            infer_field("key", key_type, self, types_path=types_path),
            infer_field("value", value_type, self, types_path=types_path),
        )

    def visit_customized(self, field_name, type_, types_path=None):
        """Customized types are passed through unchanged, wrapped in a list."""
        return [type_]

    def visit_other(self, field_name, type_, types_path=None):
        """All remaining types are passed through unchanged, wrapped in a list."""
        return [type_]
| |
| |
def get_field_names(clz, type_hints=None):
    """
    Return the sorted field names of ``clz``.

    Args:
        clz: The object to inspect. Class objects always expose ``__dict__``, so for a
            class the names come from its type hints; the ``__slots__`` branch is only
            reached for objects without ``__dict__`` (for example an instance of a
            slotted class).
        type_hints: Optional precomputed mapping of field name to type. When omitted,
            it is obtained with ``get_type_hints(clz)``.

    Returns:
        A sorted list of field names, or an empty list when ``clz`` exposes neither
        ``__dict__`` nor ``__slots__``.
    """
    if hasattr(clz, "__dict__"):
        # Fields can't be discovered without an instance, so rely on the type hints.
        if type_hints is None:
            type_hints = get_type_hints(clz)
        return sorted(type_hints)
    if hasattr(clz, "__slots__"):
        slots = clz.__slots__
        # A bare string declares ONE slot; sorting it directly would split it into characters.
        if isinstance(slots, str):
            slots = (slots,)
        return sorted(slots)
    return []
| |
| |
# When ENABLE_FORY_CYTHON_SERIALIZATION is truthy, bind the module-level names
# DataClassSerializer and DataClassStubSerializer to the implementations imported
# from pyfory.serialization.
# NOTE(review): presumably this overrides pure-Python classes of the same names
# defined earlier in this file - confirm.
if ENABLE_FORY_CYTHON_SERIALIZATION:
    from pyfory.serialization import (
        DataClassSerializer as CythonDataClassSerializer,
        DataClassStubSerializer as CythonDataClassStubSerializer,
    )

    DataClassSerializer = CythonDataClassSerializer
    DataClassStubSerializer = CythonDataClassStubSerializer