| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from __future__ import annotations |
| |
| import dataclasses |
| import datetime |
| import enum |
| import itertools |
| import logging |
| import os |
| import typing |
| from typing import List, Dict |
| |
| from pyfory.lib.mmh3 import hash_buffer |
| from pyfory.type import ( |
| TypeVisitor, |
| infer_field, |
| TypeId, |
| int8, |
| int16, |
| int32, |
| int64, |
| float32, |
| float64, |
| is_py_array_type, |
| is_list_type, |
| is_map_type, |
| get_primitive_type_size, |
| is_polymorphic_type, |
| is_primitive_type, |
| is_subclass, |
| unwrap_optional, |
| ) |
| from pyfory.buffer import Buffer |
| from pyfory.codegen import ( |
| gen_write_nullable_basic_stmts, |
| gen_read_nullable_basic_stmts, |
| compile_function, |
| ) |
| from pyfory.error import TypeNotCompatibleError |
| from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG |
| from pyfory.field import ( |
| ForyFieldMeta, |
| extract_field_meta, |
| validate_field_metas, |
| ) |
| |
| from pyfory import ( |
| Serializer, |
| BooleanSerializer, |
| ByteSerializer, |
| Int16Serializer, |
| Int32Serializer, |
| Int64Serializer, |
| Float32Serializer, |
| Float64Serializer, |
| StringSerializer, |
| ) |
| |
| logger = logging.getLogger(__name__) |
| |
| |
@dataclasses.dataclass
class FieldInfo:
    """Pre-computed field information for serialization.

    One instance is built per active (non-ignored) field by
    ``_extract_field_infos`` and consumed when generating read/write code.
    """

    # Identity
    name: str  # Field name (snake_case)
    index: int  # Field index in the serialization order
    type_hint: type  # Type annotation as declared on the dataclass

    # Fory metadata (from pyfory.field()) - used for hash computation
    tag_id: int  # -1 = use field name, >=0 = use tag ID
    nullable: bool  # Effective nullable flag (considers Optional[T])
    ref: bool  # Field-level ref setting (for hash computation)

    # Runtime flags (combines field metadata with global Fory config)
    runtime_ref_tracking: bool  # Actual ref tracking: field.ref AND fory.ref_tracking

    # Derived info
    type_id: int  # Fory TypeId (low byte of the resolver's type id)
    serializer: Serializer  # Field serializer (may be None when inference fails)
    unwrapped_type: type  # Type with Optional unwrapped
| |
| |
def _default_field_meta(type_hint: type, field_nullable: bool = False, xlang: bool = False) -> ForyFieldMeta:
    """Build the default metadata used for fields declared without pyfory.field().

    Nullability rules:
    - xlang mode: nullable only when the annotation is Optional[T].
    - native mode: nullable when the annotation is Optional[T], when the
      (unwrapped) type is non-primitive, or when the global ``field_nullable``
      flag is set.

    ``ref`` defaults to False so non-nullable complex fields keep using
    xwrite_no_ref (no ref header in the buffer); users opt in to ref tracking
    explicitly via pyfory.field(ref=True).
    """
    base_type, is_optional = unwrap_optional(type_hint)
    if xlang:
        nullable = is_optional
    else:
        nullable = is_optional or field_nullable or not is_primitive_type(base_type)
    return ForyFieldMeta(id=-1, nullable=nullable, ref=False, ignore=False)
| |
| |
def _extract_field_infos(
    fory,
    clz: type,
    type_hints: dict,
    xlang: bool = False,
) -> tuple[list[FieldInfo], dict[str, ForyFieldMeta]]:
    """
    Extract FieldInfo list from a dataclass.

    This handles:
    - Extracting field metadata from pyfory.field() annotations
    - Filtering out ignored fields
    - Computing effective nullable based on Optional[T]
    - Computing runtime ref tracking based on global config
    - Inheritance: parent fields first, subclass fields override parent fields

    Args:
        fory: Fory instance providing the type resolver and global config.
        clz: The class to inspect; non-dataclasses yield empty results.
        type_hints: Resolved type hints for ``clz``.
        xlang: If True, use xlang defaults (nullable=False except for Optional[T])

    Returns:
        Tuple of (field_infos, field_metas) where field_metas maps field name to ForyFieldMeta
    """
    if not dataclasses.is_dataclass(clz):
        # For non-dataclass, return empty - will use legacy path
        return [], {}

    # Collect fields from class hierarchy (parent first, child last)
    # Child fields override parent fields with same name
    all_fields: dict[str, dataclasses.Field] = {}
    for klass in clz.__mro__[::-1]:  # Reverse MRO: base classes first
        if dataclasses.is_dataclass(klass) and klass is not clz:
            for f in dataclasses.fields(klass):
                all_fields[f.name] = f
    # Add current class fields (override parent)
    for f in dataclasses.fields(clz):
        all_fields[f.name] = f

    # Extract field metas and filter ignored fields
    field_metas: dict[str, ForyFieldMeta] = {}
    active_fields: list[tuple[str, dataclasses.Field]] = []

    # Check if fory has field_nullable global setting
    global_field_nullable = getattr(fory, "field_nullable", False)

    for field_name, dc_field in all_fields.items():
        meta = extract_field_meta(dc_field)
        if meta is None:
            # Field without pyfory.field() - use defaults
            # Auto-detect Optional[T] for nullable, also respect global field_nullable
            field_type = type_hints.get(field_name, typing.Any)
            meta = _default_field_meta(field_type, global_field_nullable, xlang=xlang)

        field_metas[field_name] = meta

        if not meta.ignore:
            active_fields.append((field_name, dc_field))

    # Validate field metas
    validate_field_metas(clz, field_metas, type_hints)

    # Build FieldInfo list
    field_infos: list[FieldInfo] = []
    visitor = StructFieldSerializerVisitor(fory)
    global_ref_tracking = fory.ref_tracking

    for index, (field_name, dc_field) in enumerate(active_fields):
        meta = field_metas[field_name]
        type_hint = type_hints.get(field_name, typing.Any)
        unwrapped_type, is_optional = unwrap_optional(type_hint)

        # Compute effective nullable based on mode
        if xlang:
            # For xlang: respect explicit annotation or default to is_optional only
            effective_nullable = meta.nullable or is_optional
        else:
            # For native: Optional[T] or non-primitive types are nullable
            effective_nullable = meta.nullable or is_optional or not is_primitive_type(unwrapped_type)

        # Compute runtime ref tracking: field.ref AND global config
        runtime_ref = meta.ref and global_ref_tracking

        # Infer serializer
        serializer = infer_field(field_name, unwrapped_type, visitor, types_path=[])

        # Get type_id from serializer
        if serializer is not None:
            # Keep only the low byte of the resolver type id; presumably the
            # higher bits carry registration info -- verify against TypeId encoding.
            type_id = fory.type_resolver.get_typeinfo(serializer.type_).type_id & 0xFF
        else:
            type_id = TypeId.UNKNOWN

        field_info = FieldInfo(
            name=field_name,
            index=index,
            type_hint=type_hint,
            tag_id=meta.id,
            nullable=effective_nullable,
            ref=meta.ref,
            runtime_ref_tracking=runtime_ref,
            type_id=type_id,
            serializer=serializer,
            unwrapped_type=unwrapped_type,
        )
        field_infos.append(field_info)

    return field_infos, field_metas
| |
| |
# Snapshot of this module's namespace; used as the base globals for
# JIT-compiled read/write functions produced via compile_function.
# NOTE: position-sensitive -- it captures whatever is defined above this line.
_jit_context = locals()


# JIT code generation toggle. Any value other than "true"/"1"
# (case-insensitive) in the ENABLE_FORY_PYTHON_JIT env var disables it.
_ENABLE_FORY_PYTHON_JIT = os.environ.get("ENABLE_FORY_PYTHON_JIT", "True").lower() in (
    "true",
    "1",
)
| |
| |
| class DataClassSerializer(Serializer): |
| def __init__( |
| self, |
| fory, |
| clz: type, |
| xlang: bool = False, |
| field_names: List[str] = None, |
| serializers: List[Serializer] = None, |
| nullable_fields: Dict[str, bool] = None, |
| ): |
| super().__init__(fory, clz) |
| self._xlang = xlang |
| |
| self._type_hints = typing.get_type_hints(clz) |
| self._has_slots = hasattr(clz, "__slots__") |
| |
| # When field_names is explicitly passed (from TypeDef.create_serializer during schema evolution), |
| # use those fields instead of extracting from the class. This is critical for schema evolution |
| # where the sender's schema (in TypeDef) differs from the receiver's registered class. |
| # Track whether field order comes from wire (TypeDef) - don't re-sort these |
| self._fields_from_typedef = field_names is not None and serializers is not None |
| if self._fields_from_typedef: |
| # Use the passed-in field_names and serializers from TypeDef |
| self._field_names = field_names |
| self._serializers = serializers |
| self._nullable_fields = nullable_fields or {} |
| self._ref_fields = {} |
| self._field_infos = [] |
| self._field_metas = {} |
| else: |
| # Extract field infos using new pyfory.field() metadata |
| # Pass xlang to get correct nullable defaults for the mode |
| self._field_infos, self._field_metas = _extract_field_infos(fory, clz, self._type_hints, xlang=xlang) |
| |
| if self._field_infos: |
| # Use new field info based approach |
| self._field_names = [fi.name for fi in self._field_infos] |
| self._serializers = [fi.serializer for fi in self._field_infos] |
| self._nullable_fields = {fi.name: fi.nullable for fi in self._field_infos} |
| self._ref_fields = {fi.name: fi.runtime_ref_tracking for fi in self._field_infos} |
| else: |
| # Fallback for non-dataclass types |
| self._field_names = field_names or self._get_field_names(clz) |
| self._nullable_fields = nullable_fields or {} |
| self._ref_fields = {} |
| |
| if self._field_names and not self._nullable_fields: |
| for field_name in self._field_names: |
| if field_name in self._type_hints: |
| unwrapped_type, is_optional = unwrap_optional(self._type_hints[field_name]) |
| is_nullable = is_optional or not is_primitive_type(unwrapped_type) |
| self._nullable_fields[field_name] = is_nullable |
| |
| self._serializers = serializers or [None] * len(self._field_names) |
| if serializers is None: |
| visitor = StructFieldSerializerVisitor(fory) |
| for index, key in enumerate(self._field_names): |
| unwrapped_type, _ = unwrap_optional(self._type_hints.get(key, typing.Any)) |
| serializer = infer_field(key, unwrapped_type, visitor, types_path=[]) |
| self._serializers[index] = serializer |
| |
| # Cache unwrapped type hints |
| self._unwrapped_hints = self._compute_unwrapped_hints() |
| |
| if self._xlang: |
| # In xlang mode, compute struct meta for hash and field sorting |
| # BUT if fields come from TypeDef (wire data), preserve their order for deserialization |
| if self._fields_from_typedef: |
| # Fields from wire - only compute hash, don't re-sort |
| # The sender already sorted the fields, we must use their order for correct deserialization |
| hash_str = compute_struct_fingerprint( |
| fory.type_resolver, self._field_names, self._serializers, self._nullable_fields, self._field_infos |
| ) |
| hash_bytes = hash_str.encode("utf-8") |
| if len(hash_bytes) == 0: |
| self._hash = 47 |
| else: |
| from pyfory.lib.mmh3 import hash_buffer |
| |
| full_hash = hash_buffer(hash_bytes, seed=47)[0] |
| type_hash_32 = full_hash & 0xFFFFFFFF |
| if full_hash & 0x80000000: |
| type_hash_32 = type_hash_32 - 0x100000000 |
| self._hash = type_hash_32 |
| else: |
| # Fields extracted locally - sort them for consistent serialization |
| self._hash, self._field_names, self._serializers = compute_struct_meta( |
| fory.type_resolver, self._field_names, self._serializers, self._nullable_fields, self._field_infos |
| ) |
| self._generated_xwrite_method = self._gen_xwrite_method() |
| self._generated_xread_method = self._gen_xread_method() |
| if _ENABLE_FORY_PYTHON_JIT: |
| self.xwrite = self._generated_xwrite_method |
| self.xread = self._generated_xread_method |
| if self.fory.is_py: |
| logger.warning( |
| "Type of class %s shouldn't be serialized using cross-language serializer", |
| clz, |
| ) |
| else: |
| # In non-xlang mode, only sort fields in non-compatible mode |
| # In compatible mode, maintain stable field ordering for schema evolution |
| if not fory.compatible: |
| self._hash, self._field_names, self._serializers = compute_struct_meta( |
| fory.type_resolver, self._field_names, self._serializers, self._nullable_fields, self._field_infos |
| ) |
| self._generated_write_method = self._gen_write_method() |
| self._generated_read_method = self._gen_read_method() |
| if _ENABLE_FORY_PYTHON_JIT: |
| self.write = self._generated_write_method |
| self.read = self._generated_read_method |
| |
| def _get_field_names(self, clz): |
| if hasattr(clz, "__dict__"): |
| # Regular object with __dict__ |
| # For dataclasses, preserve field definition order |
| # In compatible mode, stable field ordering is critical for schema evolution |
| if dataclasses.is_dataclass(clz): |
| # Use dataclasses.fields() to get fields in definition order |
| return [field.name for field in dataclasses.fields(clz)] |
| # For non-dataclass objects, sort by key names for consistency |
| return sorted(self._type_hints.keys()) |
| elif hasattr(clz, "__slots__"): |
| # Object with __slots__ |
| return sorted(clz.__slots__) |
| return [] |
| |
| def _compute_unwrapped_hints(self): |
| """Compute unwrapped type hints once and cache.""" |
| from pyfory.type import unwrap_optional |
| |
| return {field_name: unwrap_optional(hint)[0] for field_name, hint in self._type_hints.items()} |
| |
| def _write_header(self, buffer): |
| """Write serialization header (hash or field count based on compatible mode).""" |
| if not self.fory.compatible: |
| buffer.write_int32(self._hash) |
| else: |
| buffer.write_varuint32(len(self._field_names)) |
| |
| def _read_header(self, buffer): |
| """Read serialization header and return number of fields written. |
| |
| Returns: |
| int: Number of fields that were written |
| |
| Raises: |
| TypeNotCompatibleError: If hash doesn't match in non-compatible mode |
| """ |
| if not self.fory.compatible: |
| hash_ = buffer.read_int32() |
| expected_hash = self._hash |
| if hash_ != expected_hash: |
| raise TypeNotCompatibleError(f"Hash {hash_} is not consistent with {expected_hash} for type {self.type_}") |
| return len(self._field_names) |
| else: |
| return buffer.read_varuint32() |
| |
| def _get_write_stmt_for_codegen(self, serializer, buffer, field_value): |
| """Generate write statement for code generation based on serializer type.""" |
| if isinstance(serializer, BooleanSerializer): |
| return f"{buffer}.write_bool({field_value})" |
| elif isinstance(serializer, ByteSerializer): |
| return f"{buffer}.write_int8({field_value})" |
| elif isinstance(serializer, Int16Serializer): |
| return f"{buffer}.write_int16({field_value})" |
| elif isinstance(serializer, Int32Serializer): |
| return f"{buffer}.write_varint32({field_value})" |
| elif isinstance(serializer, Int64Serializer): |
| return f"{buffer}.write_varint64({field_value})" |
| elif isinstance(serializer, Float32Serializer): |
| return f"{buffer}.write_float32({field_value})" |
| elif isinstance(serializer, Float64Serializer): |
| return f"{buffer}.write_float64({field_value})" |
| elif isinstance(serializer, StringSerializer): |
| return f"{buffer}.write_string({field_value})" |
| else: |
| return None # Complex type, needs ref handling |
| |
| def _get_read_stmt_for_codegen(self, serializer, buffer, field_value): |
| """Generate read statement for code generation based on serializer type.""" |
| if isinstance(serializer, BooleanSerializer): |
| return f"{field_value} = {buffer}.read_bool()" |
| elif isinstance(serializer, ByteSerializer): |
| return f"{field_value} = {buffer}.read_int8()" |
| elif isinstance(serializer, Int16Serializer): |
| return f"{field_value} = {buffer}.read_int16()" |
| elif isinstance(serializer, Int32Serializer): |
| return f"{field_value} = {buffer}.read_varint32()" |
| elif isinstance(serializer, Int64Serializer): |
| return f"{field_value} = {buffer}.read_varint64()" |
| elif isinstance(serializer, Float32Serializer): |
| return f"{field_value} = {buffer}.read_float32()" |
| elif isinstance(serializer, Float64Serializer): |
| return f"{field_value} = {buffer}.read_float64()" |
| elif isinstance(serializer, StringSerializer): |
| return f"{field_value} = {buffer}.read_string()" |
| else: |
| return None # Complex type, needs ref handling |
| |
| def _write_non_nullable_field(self, buffer, field_value, serializer): |
| """Write a non-nullable field value at runtime.""" |
| if isinstance(serializer, BooleanSerializer): |
| buffer.write_bool(field_value) |
| elif isinstance(serializer, ByteSerializer): |
| buffer.write_int8(field_value) |
| elif isinstance(serializer, Int16Serializer): |
| buffer.write_int16(field_value) |
| elif isinstance(serializer, Int32Serializer): |
| buffer.write_varint32(field_value) |
| elif isinstance(serializer, Int64Serializer): |
| buffer.write_varint64(field_value) |
| elif isinstance(serializer, Float32Serializer): |
| buffer.write_float32(field_value) |
| elif isinstance(serializer, Float64Serializer): |
| buffer.write_float64(field_value) |
| elif isinstance(serializer, StringSerializer): |
| buffer.write_string(field_value) |
| else: |
| self.fory.write_ref_pyobject(buffer, field_value) |
| |
| def _read_non_nullable_field(self, buffer, serializer): |
| """Read a non-nullable field value at runtime.""" |
| if isinstance(serializer, BooleanSerializer): |
| return buffer.read_bool() |
| elif isinstance(serializer, ByteSerializer): |
| return buffer.read_int8() |
| elif isinstance(serializer, Int16Serializer): |
| return buffer.read_int16() |
| elif isinstance(serializer, Int32Serializer): |
| return buffer.read_varint32() |
| elif isinstance(serializer, Int64Serializer): |
| return buffer.read_varint64() |
| elif isinstance(serializer, Float32Serializer): |
| return buffer.read_float32() |
| elif isinstance(serializer, Float64Serializer): |
| return buffer.read_float64() |
| elif isinstance(serializer, StringSerializer): |
| return buffer.read_string() |
| else: |
| return self.fory.read_ref_pyobject(buffer) |
| |
| def _write_nullable_field(self, buffer, field_value, serializer): |
| """Write a nullable field value at runtime.""" |
| if field_value is None: |
| buffer.write_int8(NULL_FLAG) |
| else: |
| buffer.write_int8(NOT_NULL_VALUE_FLAG) |
| if isinstance(serializer, StringSerializer): |
| buffer.write_string(field_value) |
| else: |
| self.fory.write_ref_pyobject(buffer, field_value) |
| |
| def _read_nullable_field(self, buffer, serializer): |
| """Read a nullable field value at runtime.""" |
| flag = buffer.read_int8() |
| if flag == NULL_FLAG: |
| return None |
| else: |
| if isinstance(serializer, StringSerializer): |
| return buffer.read_string() |
| else: |
| return self.fory.read_ref_pyobject(buffer) |
| |
    def _gen_write_method(self):
        """Generate and compile a JIT ``write`` method for the native format.

        Emits straight-line code: header (hash in non-compatible mode, field
        count in compatible mode), then one write statement per field in
        ``self._field_names`` order.
        """
        context = {}
        counter = itertools.count(0)
        buffer, fory, value, value_dict = "buffer", "fory", "value", "value_dict"
        context[fory] = self.fory
        context["_serializers"] = self._serializers

        stmts = [
            f'"""write method for {self.type_}"""',
        ]

        # Write hash only in non-compatible mode; in compatible mode, write field count
        if not self.fory.compatible:
            stmts.append(f"{buffer}.write_int32({self._hash})")
        else:
            stmts.append(f"{buffer}.write_varuint32({len(self._field_names)})")

        if not self._has_slots:
            # Bind __dict__ once so each field access is a plain dict lookup.
            stmts.append(f"{value_dict} = {value}.__dict__")

        # Write field values in order
        for index, field_name in enumerate(self._field_names):
            field_value = f"field_value{next(counter)}"
            serializer_var = f"serializer{index}"
            serializer = self._serializers[index]
            context[serializer_var] = serializer

            if not self._has_slots:
                stmts.append(f"{field_value} = {value_dict}['{field_name}']")
            else:
                stmts.append(f"{field_value} = {value}.{field_name}")

            is_nullable = self._nullable_fields.get(field_name, False)
            if is_nullable:
                # Use gen_write_nullable_basic_stmts for nullable basic types
                if isinstance(serializer, BooleanSerializer):
                    stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, bool))
                elif isinstance(serializer, (ByteSerializer, Int16Serializer, Int32Serializer, Int64Serializer)):
                    stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, int))
                elif isinstance(serializer, (Float32Serializer, Float64Serializer)):
                    stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, float))
                elif isinstance(serializer, StringSerializer):
                    stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, str))
                else:
                    # For complex types, use write_ref_pyobject
                    stmts.append(f"{fory}.write_ref_pyobject({buffer}, {field_value})")
            else:
                stmt = self._get_write_stmt_for_codegen(serializer, buffer, field_value)
                if stmt is None:
                    # Non-primitive, non-nullable field: generic ref-tracking write.
                    stmt = f"{fory}.write_ref_pyobject({buffer}, {field_value})"
                stmts.append(stmt)

        self._write_method_code, func = compile_function(
            f"write_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
            [buffer, value],
            stmts,
            context,
        )
        return func
| |
    def _gen_read_method(self):
        """Generate and compile a JIT ``read`` method for the native format.

        Emits: optional instantiation-policy check, header validation (hash in
        non-compatible mode, field count in compatible mode), object creation
        plus ref registration, then one read statement per field. In compatible
        mode each field read is guarded by ``index < num_fields_written`` so a
        sender with fewer fields deserializes cleanly.
        """
        context = dict(_jit_context)
        buffer, fory, obj_class, obj, obj_dict = (
            "buffer",
            "fory",
            "obj_class",
            "obj",
            "obj_dict",
        )
        ref_resolver = "ref_resolver"
        context[fory] = self.fory
        context[obj_class] = self.type_
        context[ref_resolver] = self.fory.ref_resolver
        context["_serializers"] = self._serializers
        current_class_field_names = set(self._get_field_names(self.type_))

        stmts = [
            f'"""read method for {self.type_}"""',
        ]
        if not self.fory.strict:
            context["checker"] = self.fory.policy
            stmts.append(f"checker.authorize_instantiation({obj_class})")

        # Read hash only in non-compatible mode; in compatible mode, read field count
        if not self.fory.compatible:
            stmts.extend(
                [
                    f"read_hash = {buffer}.read_int32()",
                    f"if read_hash != {self._hash}:",
                    f""" raise TypeNotCompatibleError(
    f"Hash {{read_hash}} is not consistent with {self._hash} for type {self.type_}")""",
                ]
            )
        else:
            stmts.append(f"num_fields_written = {buffer}.read_varuint32()")

        # __new__ skips __init__; register the instance for circular-ref resolution.
        stmts.extend(
            [
                f"{obj} = {obj_class}.__new__({obj_class})",
                f"{ref_resolver}.reference({obj})",
            ]
        )

        if not self._has_slots:
            stmts.append(f"{obj_dict} = {obj}.__dict__")

        # Read field values in order
        for index, field_name in enumerate(self._field_names):
            serializer_var = f"serializer{index}"
            serializer = self._serializers[index]
            context[serializer_var] = serializer
            field_value = f"field_value{index}"
            is_nullable = self._nullable_fields.get(field_name, False)

            # Build field reading statements
            field_stmts = []

            if is_nullable:
                # Use gen_read_nullable_basic_stmts for nullable basic types.
                # NOTE(review): the lambda captures field_value by reference;
                # safe only if gen_read_nullable_basic_stmts invokes it
                # immediately -- confirm in pyfory.codegen.
                if isinstance(serializer, BooleanSerializer):
                    field_stmts.extend(gen_read_nullable_basic_stmts(buffer, bool, lambda v: f"{field_value} = {v}"))
                elif isinstance(serializer, (ByteSerializer, Int16Serializer, Int32Serializer, Int64Serializer)):
                    field_stmts.extend(gen_read_nullable_basic_stmts(buffer, int, lambda v: f"{field_value} = {v}"))
                elif isinstance(serializer, (Float32Serializer, Float64Serializer)):
                    field_stmts.extend(gen_read_nullable_basic_stmts(buffer, float, lambda v: f"{field_value} = {v}"))
                elif isinstance(serializer, StringSerializer):
                    field_stmts.extend(gen_read_nullable_basic_stmts(buffer, str, lambda v: f"{field_value} = {v}"))
                else:
                    # For complex types, use read_ref_pyobject
                    field_stmts.append(f"{field_value} = {fory}.read_ref_pyobject({buffer})")
            else:
                stmt = self._get_read_stmt_for_codegen(serializer, buffer, field_value)
                if stmt is None:
                    stmt = f"{field_value} = {fory}.read_ref_pyobject({buffer})"
                field_stmts.append(stmt)

            # Set field value if it exists in current class
            if field_name not in current_class_field_names:
                field_stmts.append(f"# {field_name} is not in {self.type_}")
            else:
                if not self._has_slots:
                    field_stmts.append(f"{obj_dict}['{field_name}'] = {field_value}")
                else:
                    field_stmts.append(f"{obj}.{field_name} = {field_value}")

            # In compatible mode, wrap field reading in a check
            if self.fory.compatible:
                stmts.append(f"if {index} < num_fields_written:")
                # Indent all field statements
                from pyfory.codegen import ident_lines

                field_stmts = ident_lines(field_stmts)
                stmts.extend(field_stmts)
            else:
                stmts.extend(field_stmts)

        stmts.append(f"return {obj}")
        self._read_method_code, func = compile_function(
            f"read_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
            [buffer],
            stmts,
            context,
        )
        return func
| |
    def _gen_xwrite_method(self):
        """Generate JIT-compiled xwrite method.

        Per xlang spec, struct format is:
        - Schema consistent mode: |4-byte hash|field values|
        - Schema evolution mode (compatible): |field values| (no field count prefix!)
          The field count is in TypeDef meta written at the end, not in object data.
        """
        context = {}
        counter = itertools.count(0)
        buffer, fory, value, value_dict = "buffer", "fory", "value", "value_dict"
        context[fory] = self.fory
        context["_serializers"] = self._serializers
        stmts = [
            f'"""xwrite method for {self.type_}"""',
        ]
        if not self.fory.compatible:
            stmts.append(f"{buffer}.write_int32({self._hash})")
        if not self._has_slots:
            # Bind __dict__ once so each field access is a plain dict lookup.
            stmts.append(f"{value_dict} = {value}.__dict__")
        for index, field_name in enumerate(self._field_names):
            field_value = f"field_value{next(counter)}"
            serializer_var = f"serializer{index}"
            serializer = self._serializers[index]
            context[serializer_var] = serializer
            is_nullable = self._nullable_fields.get(field_name, False)
            # For schema evolution: use safe access with None default to handle
            # cases where the field might not exist on the object (missing from remote schema)
            # In compatible mode, always use safe access even for non-nullable fields
            if not self._has_slots:
                if is_nullable or self.fory.compatible:
                    stmts.append(f"{field_value} = {value_dict}.get('{field_name}')")
                else:
                    stmts.append(f"{field_value} = {value_dict}['{field_name}']")
            else:
                if is_nullable or self.fory.compatible:
                    stmts.append(f"{field_value} = getattr({value}, '{field_name}', None)")
                else:
                    stmts.append(f"{field_value} = {value}.{field_name}")
            if is_nullable:
                if isinstance(serializer, StringSerializer):
                    # Inline the null-flag protocol for strings to bypass ref machinery.
                    stmts.extend(
                        [
                            f"if {field_value} is None:",
                            f" {buffer}.write_int8({NULL_FLAG})",
                            "else:",
                            f" {buffer}.write_int8({NOT_NULL_VALUE_FLAG})",
                            f" {buffer}.write_string({field_value})",
                        ]
                    )
                else:
                    stmts.append(f"{fory}.xwrite_ref({buffer}, {field_value}, serializer={serializer_var})")
            else:
                stmt = self._get_write_stmt_for_codegen(serializer, buffer, field_value)
                if stmt is None:
                    # For non-nullable complex types, use xwrite_no_ref
                    stmt = f"{fory}.xwrite_no_ref({buffer}, {field_value}, serializer={serializer_var})"
                # In compatible mode, handle None for non-nullable fields (schema evolution)
                # Write zero/default value when field is None due to missing from remote schema
                if self.fory.compatible:
                    # Local import: EnumSerializer lives in pyfory.serializer and is
                    # not imported at module level.
                    from pyfory.serializer import EnumSerializer

                    if isinstance(serializer, EnumSerializer):
                        # For enums, write ordinal 0 when None
                        stmts.extend(
                            [
                                f"if {field_value} is None:",
                                f" {buffer}.write_varuint32(0)",
                                "else:",
                                f" {stmt}",
                            ]
                        )
                    else:
                        stmts.append(stmt)
                else:
                    stmts.append(stmt)
        self._xwrite_method_code, func = compile_function(
            f"xwrite_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
            [buffer, value],
            stmts,
            context,
        )
        return func
| |
    def _gen_xread_method(self):
        """Generate JIT-compiled xread method.

        Per xlang spec, struct format is:
        - Schema consistent mode: |4-byte hash|field values|
        - Schema evolution mode (compatible): |field values| (no field count prefix!)
          The field count is in TypeDef meta written at the end, not in object data.
        """
        context = dict(_jit_context)
        buffer, fory, obj_class, obj, obj_dict = (
            "buffer",
            "fory",
            "obj_class",
            "obj",
            "obj_dict",
        )
        ref_resolver = "ref_resolver"
        context[fory] = self.fory
        context[obj_class] = self.type_
        context[ref_resolver] = self.fory.ref_resolver
        context["_serializers"] = self._serializers
        current_class_field_names = set(self._get_field_names(self.type_))
        stmts = [
            f'"""xread method for {self.type_}"""',
        ]
        if not self.fory.strict:
            context["checker"] = self.fory.policy
            stmts.append(f"checker.authorize_instantiation({obj_class})")
        if not self.fory.compatible:
            stmts.extend(
                [
                    f"read_hash = {buffer}.read_int32()",
                    f"if read_hash != {self._hash}:",
                    f""" raise TypeNotCompatibleError(
    f"Hash {{read_hash}} is not consistent with {self._hash} for type {self.type_}")""",
                ]
            )
        # __new__ skips __init__; register the instance for circular-ref resolution.
        stmts.extend(
            [
                f"{obj} = {obj_class}.__new__({obj_class})",
                f"{ref_resolver}.reference({obj})",
            ]
        )

        if not self._has_slots:
            stmts.append(f"{obj_dict} = {obj}.__dict__")

        for index, field_name in enumerate(self._field_names):
            serializer_var = f"serializer{index}"
            serializer = self._serializers[index]
            context[serializer_var] = serializer
            field_value = f"field_value{index}"
            is_nullable = self._nullable_fields.get(field_name, False)

            if is_nullable:
                if isinstance(serializer, StringSerializer):
                    # Inline the null-flag protocol for strings (mirrors xwrite).
                    stmts.extend(
                        [
                            f"if {buffer}.read_int8() >= {NOT_NULL_VALUE_FLAG}:",
                            f" {field_value} = {buffer}.read_string()",
                            "else:",
                            f" {field_value} = None",
                        ]
                    )
                else:
                    stmts.append(f"{field_value} = {fory}.xread_ref({buffer}, serializer={serializer_var})")
            else:
                stmt = self._get_read_stmt_for_codegen(serializer, buffer, field_value)
                if stmt is None:
                    # For non-nullable complex types, use xread_no_ref
                    stmt = f"{field_value} = {fory}.xread_no_ref({buffer}, serializer={serializer_var})"
                stmts.append(stmt)

            if field_name not in current_class_field_names:
                stmts.append(f"# {field_name} is not in {self.type_}")
            elif not self._has_slots:
                stmts.append(f"{obj_dict}['{field_name}'] = {field_value}")
            else:
                stmts.append(f"{obj}.{field_name} = {field_value}")

        # For schema evolution: initialize missing fields with default values
        # This handles cases where the sender's schema has fewer fields than the receiver's
        if self.fory.compatible:
            read_field_names = set(self._field_names)
            missing_fields = current_class_field_names - read_field_names
            if missing_fields and dataclasses.is_dataclass(self.type_):
                for dc_field in dataclasses.fields(self.type_):
                    if dc_field.name in missing_fields:
                        if dc_field.default is not dataclasses.MISSING:
                            # Embed the default's repr directly in the generated source.
                            default_val = repr(dc_field.default)
                            if not self._has_slots:
                                stmts.append(f"{obj_dict}['{dc_field.name}'] = {default_val}")
                            else:
                                stmts.append(f"{obj}.{dc_field.name} = {default_val}")
                        elif dc_field.default_factory is not dataclasses.MISSING:
                            # Factories can't be repr'd; pass them via the context namespace.
                            factory_var = f"_default_factory_{dc_field.name}"
                            context[factory_var] = dc_field.default_factory
                            if not self._has_slots:
                                stmts.append(f"{obj_dict}['{dc_field.name}'] = {factory_var}()")
                            else:
                                stmts.append(f"{obj}.{dc_field.name} = {factory_var}()")
                        # else: field has no default, leave it unset

        stmts.append(f"return {obj}")
        self._xread_method_code, func = compile_function(
            f"xread_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
            [buffer],
            stmts,
            context,
        )
        return func
| |
| def write(self, buffer, value): |
| """Write dataclass instance to buffer in Python native format.""" |
| self._write_header(buffer) |
| |
| for index, field_name in enumerate(self._field_names): |
| field_value = getattr(value, field_name) |
| serializer = self._serializers[index] |
| is_nullable = self._nullable_fields.get(field_name, False) |
| |
| if is_nullable: |
| self._write_nullable_field(buffer, field_value, serializer) |
| else: |
| self._write_non_nullable_field(buffer, field_value, serializer) |
| |
| def read(self, buffer): |
| """Read dataclass instance from buffer in Python native format.""" |
| num_fields_written = self._read_header(buffer) |
| |
| obj = self.type_.__new__(self.type_) |
| self.fory.ref_resolver.reference(obj) |
| current_class_field_names = set(self._get_field_names(self.type_)) |
| |
| for index, field_name in enumerate(self._field_names): |
| # Only read if this field was written |
| if index >= num_fields_written: |
| break |
| |
| serializer = self._serializers[index] |
| is_nullable = self._nullable_fields.get(field_name, False) |
| |
| if is_nullable: |
| field_value = self._read_nullable_field(buffer, serializer) |
| else: |
| field_value = self._read_non_nullable_field(buffer, serializer) |
| |
| if field_name in current_class_field_names: |
| setattr(obj, field_name, field_value) |
| return obj |
| |
| def xwrite(self, buffer: Buffer, value): |
| """Write dataclass instance to buffer in cross-language format. |
| |
| Per xlang spec, struct format is: |
| - Schema consistent mode: |4-byte hash|field values| |
| - Schema evolution mode (compatible): |field values| (no field count prefix!) |
| The field count is in TypeDef meta written at the end, not in object data. |
| """ |
| if not self._xlang: |
| raise TypeError("xwrite can only be called when DataClassSerializer is in xlang mode") |
| if not self.fory.compatible: |
| buffer.write_int32(self._hash) |
| for index, field_name in enumerate(self._field_names): |
| field_value = getattr(value, field_name) |
| serializer = self._serializers[index] |
| is_nullable = self._nullable_fields.get(field_name, False) |
| if is_nullable: |
| if field_value is None: |
| buffer.write_int8(-3) |
| else: |
| self.fory.xwrite_ref(buffer, field_value, serializer=serializer) |
| else: |
| serializer.xwrite(buffer, field_value) |
| |
| def xread(self, buffer): |
| """Read dataclass instance from buffer in cross-language format. |
| |
| Per xlang spec, struct format is: |
| - Schema consistent mode: |4-byte hash|field values| |
| - Schema evolution mode (compatible): |field values| (no field count prefix!) |
| The field count is in TypeDef meta written at the end, not in object data. |
| """ |
| if not self._xlang: |
| raise TypeError("xread can only be called when DataClassSerializer is in xlang mode") |
| if not self.fory.compatible: |
| hash_ = buffer.read_int32() |
| if hash_ != self._hash: |
| raise TypeNotCompatibleError( |
| f"Hash {hash_} is not consistent with {self._hash} for type {self.type_}", |
| ) |
| obj = self.type_.__new__(self.type_) |
| self.fory.ref_resolver.reference(obj) |
| current_class_field_names = set(self._get_field_names(self.type_)) |
| read_field_names = set() |
| for index, field_name in enumerate(self._field_names): |
| serializer = self._serializers[index] |
| is_nullable = self._nullable_fields.get(field_name, False) |
| if is_nullable: |
| ref_id = buffer.read_int8() |
| if ref_id == -3: |
| field_value = None |
| else: |
| buffer.reader_index -= 1 |
| field_value = self.fory.xread_ref(buffer, serializer=serializer) |
| else: |
| field_value = serializer.xread(buffer) |
| if field_name in current_class_field_names: |
| setattr(obj, field_name, field_value) |
| read_field_names.add(field_name) |
| # For schema evolution: initialize missing fields with default values |
| # This handles cases where the sender's schema has fewer fields than the receiver's |
| if self.fory.compatible: |
| missing_fields = current_class_field_names - read_field_names |
| if missing_fields and dataclasses.is_dataclass(self.type_): |
| for dc_field in dataclasses.fields(self.type_): |
| if dc_field.name in missing_fields: |
| if dc_field.default is not dataclasses.MISSING: |
| setattr(obj, dc_field.name, dc_field.default) |
| elif dc_field.default_factory is not dataclasses.MISSING: |
| setattr(obj, dc_field.name, dc_field.default_factory()) |
| # else: field has no default, leave it unset (will be None for nullable) |
| return obj |
| |
| |
class DataClassStubSerializer(DataClassSerializer):
    """Lazy placeholder that builds the real ``DataClassSerializer`` on first use.

    Registration may happen before the concrete serializer is required; every
    (x)read/(x)write call materializes the real serializer, installs it on the
    typeinfo, and delegates the call to it.
    """

    def __init__(self, fory, clz: type, xlang: bool = False):
        Serializer.__init__(self, fory, clz)
        self.xlang = xlang

    def write(self, buffer, value):
        self._replace().write(buffer, value)

    def read(self, buffer):
        return self._replace().read(buffer)

    def xwrite(self, buffer, value):
        self._replace().xwrite(buffer, value)

    def xread(self, buffer):
        return self._replace().xread(buffer)

    def _replace(self):
        # Build the concrete serializer and cache it on the typeinfo so that
        # subsequent (de)serializations bypass this stub entirely.
        info = self.fory.type_resolver.get_typeinfo(self.type_)
        info.serializer = DataClassSerializer(self.fory, self.type_, self.xlang)
        return info.serializer
| |
| |
# Leaf field types that the struct visitors below treat as directly
# resolvable via the type resolver (builtins, fixed-width numerics, and
# date/time types).
basic_types = {
    bool,
    int,
    float,
    str,
    bytes,
    int8,
    int16,
    int32,
    int64,
    float32,
    float64,
    datetime.datetime,
    datetime.date,
    datetime.time,
}
| |
| |
class StructFieldSerializerVisitor(TypeVisitor):
    """Resolves a serializer for each struct field from its type annotation."""

    def __init__(self, fory):
        self.fory = fory

    def visit_list(self, field_name, elem_type, types_path=None):
        from pyfory.serializer import ListSerializer  # local import

        # Recurse so nested annotations like List[Dict[str, str]] resolve fully.
        item_serializer = infer_field("item", elem_type, self, types_path=types_path)
        return ListSerializer(self.fory, list, item_serializer)

    def visit_set(self, field_name, elem_type, types_path=None):
        from pyfory.serializer import SetSerializer  # local import

        # Recurse so nested annotations like Set[Dict[str, str]] resolve fully.
        item_serializer = infer_field("item", elem_type, self, types_path=types_path)
        return SetSerializer(self.fory, set, item_serializer)

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        from pyfory.serializer import MapSerializer  # local import

        # Recurse so nested annotations like Dict[str, Dict[str, str]] resolve fully.
        k_serializer = infer_field("key", key_type, self, types_path=types_path)
        v_serializer = infer_field("value", value_type, self, types_path=types_path)
        return MapSerializer(self.fory, dict, k_serializer, v_serializer)

    def visit_customized(self, field_name, type_, types_path=None):
        if issubclass(type_, enum.Enum):
            return self.fory.type_resolver.get_serializer(type_)
        # For custom types (dataclasses, inner structs): reuse a serializer if
        # one is already registered; otherwise defer resolution (return None).
        info = self.fory.type_resolver.get_typeinfo(type_, create=False)
        return info.serializer if info is not None else None

    def visit_other(self, field_name, type_, types_path=None):
        if is_subclass(type_, enum.Enum):
            return self.fory.type_resolver.get_serializer(type_)
        if type_ in basic_types or is_py_array_type(type_):
            return self.fory.type_resolver.get_serializer(type_)
        return None
| |
| |
# Sentinel type id for fields whose serializer could not be resolved.
_UNKNOWN_TYPE_ID = -1
# NOTE(review): _time_types is not referenced anywhere in this portion of the
# file — confirm it is still used elsewhere before removing.
_time_types = {datetime.date, datetime.datetime, datetime.timedelta}
| |
| |
def _sort_fields(type_resolver, field_names, serializers, nullable_map=None):
    """Return (sorted_field_names, sorted_serializers) in canonical field order.

    Delegates grouping and per-group sorting to ``group_fields`` and then
    flattens the groups in their defined order.
    """
    grouped = group_fields(type_resolver, field_names, serializers, nullable_map)
    ordered = list(itertools.chain.from_iterable(grouped))
    # Each entry is a (type_id, serializer, field_name) triple.
    return [entry[2] for entry in ordered], [entry[1] for entry in ordered]
| |
| |
def group_fields(type_resolver, field_names, serializers, nullable_map=None):
    """Partition struct fields into the serialization ordering groups.

    Returns a 7-tuple of lists
    (boxed, nullable_boxed, internal, collection, set, map, other), where each
    entry is a ``(type_id, serializer, field_name)`` triple and each group is
    already sorted. The resulting concatenation order is the canonical field
    order used when writing/reading struct data.
    """
    nullable_map = nullable_map or {}
    boxed_types = []
    nullable_boxed_types = []
    collection_types = []
    set_types = []
    map_types = []
    internal_types = []
    other_types = []
    type_ids = []
    for field_name, serializer in zip(field_names, serializers):
        if serializer is None:
            # Unresolved serializer: tag with the sentinel id and sort with
            # the "other" group.
            other_types.append((_UNKNOWN_TYPE_ID, serializer, field_name))
        else:
            # Low byte only: the upper bits of type_id carry flags, not the
            # base type category.
            type_ids.append(
                (
                    type_resolver.get_typeinfo(serializer.type_).type_id & 0xFF,
                    serializer,
                    field_name,
                )
            )
    for type_id, serializer, field_name in type_ids:
        is_nullable = nullable_map.get(field_name, False)
        if is_primitive_type(type_id):
            # Nullable primitives are grouped separately from plain primitives.
            container = nullable_boxed_types if is_nullable else boxed_types
        elif type_id == TypeId.SET:
            container = set_types
        elif is_list_type(serializer.type_):
            container = collection_types
        elif is_map_type(serializer.type_):
            container = map_types
        elif is_polymorphic_type(type_id) or type_id in {
            TypeId.ENUM,
            TypeId.NAMED_ENUM,
        }:
            container = other_types
        else:
            # Remaining ids must be internal (framework-defined) type ids.
            assert TypeId.UNKNOWN < type_id < TypeId.BOUND, (type_id,)
            container = internal_types
        container.append((type_id, serializer, field_name))

    def sorter(item):
        # Sort by (type_id, field_name).
        return item[0], item[2]

    def numeric_sorter(item):
        id_ = item[0]
        compress = id_ in {
            TypeId.INT32,
            TypeId.INT64,
            TypeId.VAR_INT32,
            TypeId.VAR_INT64,
        }
        # Sort by: compress flag, -size (largest first), -type_id (higher type ID first), field_name
        # Java sorts by size (largest first), then by primitive type ID (descending)
        return int(compress), -get_primitive_type_size(id_), -id_, item[2]

    boxed_types = sorted(boxed_types, key=numeric_sorter)
    nullable_boxed_types = sorted(nullable_boxed_types, key=numeric_sorter)
    collection_types = sorted(collection_types, key=sorter)
    set_types = sorted(set_types, key=sorter)
    internal_types = sorted(internal_types, key=sorter)
    map_types = sorted(map_types, key=sorter)
    other_types = sorted(other_types, key=lambda item: item[2])
    return (boxed_types, nullable_boxed_types, internal_types, collection_types, set_types, map_types, other_types)
| |
| |
def compute_struct_fingerprint(type_resolver, field_names, serializers, nullable_map=None, field_infos_list=None):
    """
    Build the fingerprint string for a struct type used in schema versioning.

    Each field contributes ``<field_id_or_name>,<type_id>,<ref>,<nullable>;``
    where:
      - field_id_or_name: tag ID as a decimal string when id >= 0, else the
        field name
      - type_id: the Fory TypeId as a decimal string (e.g. "4" for INT32)
      - ref: "1" when the field annotation (pyfory.field(ref=True)) requested
        ref tracking, "0" otherwise (annotation-based, NOT runtime config)
      - nullable: "1" when a null flag is written for the field, "0" otherwise

    Fields with tag IDs sort first (numerically), then name-based fields
    (lexicographically). The format is shared with the Go, Java, Rust and C++
    implementations.

    Example fingerprints:
        With tag IDs: "0,4,0,0;1,12,0,1;2,0,0,1;"
        With field names: "age,4,0,0;email,12,0,1;name,9,0,0;"
    """
    nullable_map = nullable_map or {}
    # Optional per-field metadata (tag id / ref flag), looked up by name.
    meta_by_name = {fi.name: fi for fi in field_infos_list} if field_infos_list else {}

    # Entries: (sort_key, field_id_or_name, type_id, ref_flag, nullable_flag)
    entries = []
    for field_name, serializer in zip(field_names, serializers):
        meta = meta_by_name.get(field_name)
        tag_id = meta.tag_id if meta else -1
        ref_flag = "1" if (meta and meta.ref) else "0"
        # Nullability comes from the precomputed map in both branches; for
        # xlang this already defaults to False except Optional[T]/explicit
        # annotations.
        nullable_flag = "1" if nullable_map.get(field_name, False) else "0"

        if serializer is None:
            type_id = TypeId.UNKNOWN
        else:
            type_id = type_resolver.get_typeinfo(serializer.type_).type_id & 0xFF
            # Polymorphic and enum fields are recorded as UNKNOWN while the
            # nullable flag from the map is preserved.
            if is_polymorphic_type(type_id) or type_id in {TypeId.ENUM, TypeId.NAMED_ENUM}:
                type_id = TypeId.UNKNOWN

        if tag_id >= 0:
            # Tag-ID fields come first, ordered numerically by id.
            entries.append(((0, tag_id, ""), str(tag_id), type_id, ref_flag, nullable_flag))
        else:
            # Name-based fields come after, ordered lexicographically by name.
            entries.append(((1, 0, field_name), field_name, type_id, ref_flag, nullable_flag))

    entries.sort(key=lambda entry: entry[0])
    return "".join(f"{ident},{tid},{ref},{nullable};" for _, ident, tid, ref, nullable in entries)
| |
| |
def compute_struct_meta(type_resolver, field_names, serializers, nullable_map=None, field_infos_list=None):
    """
    Computes struct metadata including version hash, sorted field names, and serializers.

    Uses compute_struct_fingerprint to build the fingerprint string, then hashes it
    with MurmurHash3 using seed 47, and takes the low 32 bits as signed int32.

    This provides the cross-language struct version ID used by class version checking,
    consistent with Go, Java, Rust, and C++ implementations.

    Returns:
        (type_hash_32, sorted_field_names, sorted_serializers)
    """
    # Reuse _sort_fields for the canonical field ordering instead of
    # re-flattening group_fields' output here (removes duplicated logic).
    sorted_field_names, sorted_serializers = _sort_fields(type_resolver, field_names, serializers, nullable_map)

    # Compute fingerprint string using the new format with field infos.
    hash_str = compute_struct_fingerprint(type_resolver, field_names, serializers, nullable_map, field_infos_list)
    hash_bytes = hash_str.encode("utf-8")

    # Handle empty hash_bytes (no fields or all fields are unknown/dynamic).
    if len(hash_bytes) == 0:
        full_hash = 47  # Use seed as default hash for empty structs
    else:
        full_hash = hash_buffer(hash_bytes, seed=47)[0]
    # Truncate to signed int32 (two's complement), matching other languages.
    type_hash_32 = full_hash & 0xFFFFFFFF
    if full_hash & 0x80000000:
        # Sign bit set: subtract 2^32 to get the correct negative value.
        type_hash_32 = type_hash_32 - 0x100000000
    assert type_hash_32 != 0
    if os.environ.get("ENABLE_FORY_DEBUG_OUTPUT", "").lower() in ("1", "true"):
        print(f'[Python][fory-debug] struct version fingerprint="{hash_str}" version hash={type_hash_32}')

    return type_hash_32, sorted_field_names, sorted_serializers
| |
| |
class StructTypeIdVisitor(TypeVisitor):
    """Maps each struct field annotation to its Fory type-id structure."""

    def __init__(self, fory, cls):
        self.fory = fory
        self.cls = cls

    def visit_list(self, field_name, elem_type, types_path=None):
        # Recurse for nested annotations such as List[Dict[str, str]].
        item_ids = infer_field("item", elem_type, self, types_path=types_path)
        return TypeId.LIST, item_ids

    def visit_set(self, field_name, elem_type, types_path=None):
        # Recurse for nested annotations such as Set[Dict[str, str]].
        item_ids = infer_field("item", elem_type, self, types_path=types_path)
        return TypeId.SET, item_ids

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        # Recurse for nested annotations such as Dict[str, Dict[str, str]].
        k_ids = infer_field("key", key_type, self, types_path=types_path)
        v_ids = infer_field("value", value_type, self, types_path=types_path)
        return TypeId.MAP, k_ids, v_ids

    def visit_customized(self, field_name, type_, types_path=None):
        # Unregistered custom types map to UNKNOWN.
        info = self.fory.type_resolver.get_typeinfo(type_, create=False)
        return [TypeId.UNKNOWN] if info is None else [info.type_id]

    def visit_other(self, field_name, type_, types_path=None):
        if is_subclass(type_, enum.Enum):
            return [self.fory.type_resolver.get_typeinfo(type_).type_id]
        if type_ not in basic_types and not is_py_array_type(type_):
            return None, None
        return [self.fory.type_resolver.get_typeinfo(type_).type_id]
| |
| |
class StructTypeVisitor(TypeVisitor):
    """Maps each struct field annotation to its (generic) Python type structure."""

    def __init__(self, cls):
        self.cls = cls

    def visit_list(self, field_name, elem_type, types_path=None):
        # Recurse for nested annotations such as List[Dict[str, str]].
        return typing.List, infer_field("item", elem_type, self, types_path=types_path)

    def visit_set(self, field_name, elem_type, types_path=None):
        # Recurse for nested annotations such as Set[Dict[str, str]].
        return typing.Set, infer_field("item", elem_type, self, types_path=types_path)

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        # Recurse for nested annotations such as Dict[str, Dict[str, str]].
        key_types = infer_field("key", key_type, self, types_path=types_path)
        value_types = infer_field("value", value_type, self, types_path=types_path)
        return typing.Dict, key_types, value_types

    def visit_customized(self, field_name, type_, types_path=None):
        # Custom types pass through unchanged.
        return [type_]

    def visit_other(self, field_name, type_, types_path=None):
        # Everything else passes through unchanged.
        return [type_]
| |
| |
def get_field_names(clz, type_hints=None):
    """Return the sorted candidate field names for *clz*.

    Without an instance the actual attributes are unknown, so field names are
    taken from the class's type hints (computed via ``typing.get_type_hints``
    when *type_hints* is not supplied).
    """
    if hasattr(clz, "__dict__"):
        # Ordinary class: rely on type hints since we have no instance.
        hints = typing.get_type_hints(clz) if type_hints is None else type_hints
        return sorted(hints)
    if hasattr(clz, "__slots__"):
        # NOTE(review): class objects always expose a class-level __dict__, so
        # this branch looks unreachable for classes — confirm it is still
        # needed before relying on it.
        return sorted(clz.__slots__)
    return []