| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import array |
| import builtins |
| import dataclasses |
| import importlib |
| import inspect |
| import itertools |
| import marshal |
| import logging |
| import os |
| import pickle |
| import types |
| import typing |
| from typing import List, Dict |
| |
| from pyfory.buffer import Buffer |
| from pyfory.codegen import ( |
| gen_write_nullable_basic_stmts, |
| gen_read_nullable_basic_stmts, |
| compile_function, |
| ) |
| from pyfory.error import TypeNotCompatibleError |
| from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG |
| from pyfory import Language |
| |
| from pyfory.type import is_primitive_type |
| |
| try: |
| import numpy as np |
| except ImportError: |
| np = None |
| |
| from pyfory._fory import ( |
| NOT_NULL_INT64_FLAG, |
| BufferObject, |
| ) |
| |
# True on Windows; the array typecode tables below map int32/int64 to "l"/"q" there
# instead of "i"/"l".
_WINDOWS = os.name == "nt"
| |
| from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION |
| |
| if ENABLE_FORY_CYTHON_SERIALIZATION: |
| from pyfory.serialization import ( # noqa: F401, F811 |
| Serializer, |
| XlangCompatibleSerializer, |
| BooleanSerializer, |
| ByteSerializer, |
| Int16Serializer, |
| Int32Serializer, |
| Int64Serializer, |
| Float32Serializer, |
| Float64Serializer, |
| StringSerializer, |
| DateSerializer, |
| TimestampSerializer, |
| CollectionSerializer, |
| ListSerializer, |
| TupleSerializer, |
| StringArraySerializer, |
| SetSerializer, |
| MapSerializer, |
| EnumSerializer, |
| SliceSerializer, |
| ) |
| else: |
| from pyfory._serializer import ( # noqa: F401 # pylint: disable=unused-import |
| Serializer, |
| XlangCompatibleSerializer, |
| BooleanSerializer, |
| ByteSerializer, |
| Int16Serializer, |
| Int32Serializer, |
| Int64Serializer, |
| Float32Serializer, |
| Float64Serializer, |
| StringSerializer, |
| DateSerializer, |
| TimestampSerializer, |
| CollectionSerializer, |
| ListSerializer, |
| TupleSerializer, |
| StringArraySerializer, |
| SetSerializer, |
| MapSerializer, |
| EnumSerializer, |
| SliceSerializer, |
| ) |
| |
| from pyfory.type import ( |
| int16_array, |
| int32_array, |
| int64_array, |
| float32_array, |
| float64_array, |
| BoolNDArrayType, |
| Int16NDArrayType, |
| Int32NDArrayType, |
| Int64NDArrayType, |
| Float32NDArrayType, |
| Float64NDArrayType, |
| TypeId, |
| infer_field, # Added infer_field |
| ) |
| |
| |
class NoneSerializer(Serializer):
    """Serializer for ``None`` in the Python-native format.

    ``None`` carries no payload: ``write`` emits nothing and ``read`` always
    returns ``None``. The cross-language methods are not supported.
    """

    def __init__(self, fory):
        super().__init__(fory, None)
        # Disable reference tracking for None values.
        self.need_to_write_ref = False

    def write(self, buffer, value):
        """Emit nothing; the value is implied by the type."""
        return None

    def read(self, buffer):
        """Consume nothing and return ``None``."""
        return None

    def xwrite(self, buffer, value):
        raise NotImplementedError

    def xread(self, buffer):
        raise NotImplementedError
| |
| |
class PandasRangeIndexSerializer(Serializer):
    """Python-native serializer for ``pandas.RangeIndex``.

    Layout: ``start``, ``stop`` and ``step`` (each as a flagged bound, see
    ``_write_bound``), then ``dtype`` and ``name`` via ``write_ref``.
    Cross-language serialization is not supported.
    """

    __slots__ = "_cached"

    def __init__(self, fory):
        import pandas as pd

        super().__init__(fory, pd.RangeIndex)

    def _write_bound(self, buffer, bound):
        """Write one of start/stop/step.

        Exact ``int`` values use a fast path: the combined
        ``NOT_NULL_INT64_FLAG`` as an int16, then the value as a varint64.
        ``None`` is a single ``NULL_FLAG`` byte; any other value is
        ``NOT_NULL_VALUE_FLAG`` followed by ``fory.write_no_ref``.
        """
        if type(bound) is int:
            buffer.write_int16(NOT_NULL_INT64_FLAG)
            buffer.write_varint64(bound)
        elif bound is None:
            buffer.write_int8(NULL_FLAG)
        else:
            buffer.write_int8(NOT_NULL_VALUE_FLAG)
            self.fory.write_no_ref(buffer, bound)

    def _read_bound(self, buffer):
        """Read one bound written by ``_write_bound`` (``None`` if null)."""
        if buffer.read_int8() == NULL_FLAG:
            return None
        return self.fory.read_no_ref(buffer)

    def write(self, buffer, value):
        bounds = (value.start, value.stop, value.step)
        for bound in bounds:
            self._write_bound(buffer, bound)
        self.fory.write_ref(buffer, value.dtype)
        self.fory.write_ref(buffer, value.name)

    def read(self, buffer):
        start = self._read_bound(buffer)
        stop = self._read_bound(buffer)
        step = self._read_bound(buffer)
        dtype = self.fory.read_ref(buffer)
        name = self.fory.read_ref(buffer)
        return self.type_(start, stop, step, dtype=dtype, name=name)

    def xwrite(self, buffer, value):
        raise NotImplementedError

    def xread(self, buffer):
        raise NotImplementedError
| |
| |
# Module namespace used as the base context for generated read methods.
# At module scope ``locals()`` and ``globals()`` are the same dict, so a
# ``dict(_jit_context)`` copy taken at codegen time contains every module-level
# name defined by then (e.g. ``TypeNotCompatibleError`` used by generated code).
_jit_context = globals()


# JIT code generation stays enabled unless ENABLE_FORY_PYTHON_JIT is set to a
# value other than "true"/"1" (case-insensitive).
_ENABLE_FORY_PYTHON_JIT = os.environ.get("ENABLE_FORY_PYTHON_JIT", "True").lower() in {"true", "1"}
| |
| |
| from pyfory._struct import compute_struct_meta, StructFieldSerializerVisitor |
| |
| |
class DataClassSerializer(Serializer):
    """Field-by-field serializer for dataclasses and type-annotated classes.

    Fields are written in ``self._field_names`` order. In non-compatible mode
    the payload starts with a 32-bit struct hash that the reader verifies. In
    compatible mode the Python-native format starts with the field count
    instead; the xlang format writes neither. When ``_ENABLE_FORY_PYTHON_JIT``
    is set, specialized methods are generated and bound on the instance. The
    interpreted ``write``/``read``/``xwrite``/``xread`` below are the
    fallbacks and are meant to produce the same bytes as the generated ones.
    """

    # Serializers whose values are written inline (no ref flag / type info)
    # when the field is non-nullable; mirrors _get_write_stmt_for_codegen.
    _BASIC_FIELD_SERIALIZERS = (
        BooleanSerializer,
        ByteSerializer,
        Int16Serializer,
        Int32Serializer,
        Int64Serializer,
        Float32Serializer,
        Float64Serializer,
        StringSerializer,
    )

    def __init__(
        self,
        fory,
        clz: type,
        xlang: bool = False,
        field_names: List[str] = None,
        serializers: List[Serializer] = None,
        nullable_fields: Dict[str, bool] = None,
    ):
        """Compute the field layout for ``clz`` and generate JIT methods.

        Args:
            fory: Owning Fory instance (mode flags, resolvers, policy).
            clz: The class to serialize.
            xlang: Use the cross-language (``xwrite``/``xread``) format.
            field_names: Explicit field order; defaults to ``_get_field_names``.
            serializers: Per-field serializers aligned with ``field_names``;
                inferred from the type hints when ``None``.
            nullable_fields: Field name -> nullable flag; derived from the type
                hints when empty or ``None``.
        """
        super().__init__(fory, clz)
        self._xlang = xlang
        from pyfory.type import unwrap_optional

        self._type_hints = typing.get_type_hints(clz)
        self._field_names = field_names or self._get_field_names(clz)
        self._has_slots = hasattr(clz, "__slots__")
        self._nullable_fields = nullable_fields or {}
        field_nullable = fory.field_nullable
        if self._field_names and not self._nullable_fields:
            for field_name in self._field_names:
                if field_name in self._type_hints:
                    unwrapped_type, is_nullable = unwrap_optional(self._type_hints[field_name], field_nullable=field_nullable)
                    # Non-primitive fields are always treated as nullable.
                    is_nullable = is_nullable or not is_primitive_type(unwrapped_type)
                    self._nullable_fields[field_name] = is_nullable

        # Cache unwrapped type hints
        self._unwrapped_hints = self._compute_unwrapped_hints()

        if serializers is None:
            self._serializers = self._infer_field_serializers(fory)
        else:
            # An explicitly empty list means "no serializer" for every field.
            self._serializers = serializers or [None] * len(self._field_names)

        if self._xlang:
            self._hash, self._field_names, self._serializers = compute_struct_meta(
                fory.type_resolver, self._field_names, self._serializers, self._nullable_fields
            )
            self._generated_xwrite_method = self._gen_xwrite_method()
            self._generated_xread_method = self._gen_xread_method()
            if _ENABLE_FORY_PYTHON_JIT:
                # don't use `__slots__`, which will make the instance method read-only
                self.xwrite = self._generated_xwrite_method
                self.xread = self._generated_xread_method
            if self.fory.language == Language.PYTHON:
                logger = logging.getLogger(__name__)
                logger.warning(
                    "Type of class %s shouldn't be serialized using cross-language serializer",
                    clz,
                )
        else:
            # Python-native mode uses the same field infrastructure as xlang.
            # In compatible mode the declared field order is kept (stable order
            # matters for schema evolution); otherwise compute_struct_meta
            # reorders the fields and produces the struct hash.
            if not fory.compatible:
                self._hash, self._field_names, self._serializers = compute_struct_meta(
                    fory.type_resolver, self._field_names, self._serializers, self._nullable_fields
                )
            self._generated_write_method = self._gen_write_method()
            self._generated_read_method = self._gen_read_method()
            if _ENABLE_FORY_PYTHON_JIT:
                # don't use `__slots__`, which will make instance method readonly
                self.write = self._generated_write_method
                self.read = self._generated_read_method

    def _infer_field_serializers(self, fory):
        """Infer one serializer per entry of ``self._field_names`` from type hints.

        Entries may be ``None`` when ``infer_field`` yields no serializer; such
        fields are then handled dynamically.
        """
        from pyfory.type import unwrap_optional

        visitor = StructFieldSerializerVisitor(fory)
        inferred = []
        for key in self._field_names:
            unwrapped_type, _ = unwrap_optional(self._type_hints[key])
            inferred.append(infer_field(key, unwrapped_type, visitor, types_path=[]))
        return inferred

    def _get_field_names(self, clz):
        """Return the default field order for ``clz``.

        Dataclasses use declaration order (``dataclasses.fields``), which keeps
        ordering stable for compatible-mode schema evolution. Other classes use
        their type-hinted attribute names, sorted.
        """
        # NOTE(review): `hasattr(clz, "__dict__")` is true for every class
        # object (it is the class namespace), so the `__slots__` branch below
        # is unreachable. Changing it would change field order and thus the
        # struct hash, so behavior is left as is.
        if hasattr(clz, "__dict__"):
            if dataclasses.is_dataclass(clz):
                return [field.name for field in dataclasses.fields(clz)]
            return sorted(self._type_hints.keys())
        elif hasattr(clz, "__slots__"):
            return sorted(clz.__slots__)
        return []

    def _compute_unwrapped_hints(self):
        """Compute unwrapped type hints once and cache."""
        from pyfory.type import unwrap_optional

        return {field_name: unwrap_optional(hint)[0] for field_name, hint in self._type_hints.items()}

    def _write_header(self, buffer):
        """Write serialization header (hash or field count based on compatible mode)."""
        if not self.fory.compatible:
            buffer.write_int32(self._hash)
        else:
            buffer.write_varuint32(len(self._field_names))

    def _read_header(self, buffer):
        """Read serialization header and return number of fields written.

        Returns:
            int: Number of fields that were written

        Raises:
            TypeNotCompatibleError: If hash doesn't match in non-compatible mode
        """
        if not self.fory.compatible:
            hash_ = buffer.read_int32()
            expected_hash = self._hash
            if hash_ != expected_hash:
                raise TypeNotCompatibleError(f"Hash {hash_} is not consistent with {expected_hash} for type {self.type_}")
            return len(self._field_names)
        else:
            return buffer.read_varuint32()

    def _get_write_stmt_for_codegen(self, serializer, buffer, field_value):
        """Return source for an inline write of a basic-typed value, else ``None``."""
        if isinstance(serializer, BooleanSerializer):
            return f"{buffer}.write_bool({field_value})"
        elif isinstance(serializer, ByteSerializer):
            return f"{buffer}.write_int8({field_value})"
        elif isinstance(serializer, Int16Serializer):
            return f"{buffer}.write_int16({field_value})"
        elif isinstance(serializer, Int32Serializer):
            return f"{buffer}.write_varint32({field_value})"
        elif isinstance(serializer, Int64Serializer):
            return f"{buffer}.write_varint64({field_value})"
        elif isinstance(serializer, Float32Serializer):
            return f"{buffer}.write_float32({field_value})"
        elif isinstance(serializer, Float64Serializer):
            return f"{buffer}.write_float64({field_value})"
        elif isinstance(serializer, StringSerializer):
            return f"{buffer}.write_string({field_value})"
        else:
            return None  # Complex type, needs ref handling

    def _get_read_stmt_for_codegen(self, serializer, buffer, field_value):
        """Return source for an inline read into ``field_value``, else ``None``."""
        if isinstance(serializer, BooleanSerializer):
            return f"{field_value} = {buffer}.read_bool()"
        elif isinstance(serializer, ByteSerializer):
            return f"{field_value} = {buffer}.read_int8()"
        elif isinstance(serializer, Int16Serializer):
            return f"{field_value} = {buffer}.read_int16()"
        elif isinstance(serializer, Int32Serializer):
            return f"{field_value} = {buffer}.read_varint32()"
        elif isinstance(serializer, Int64Serializer):
            return f"{field_value} = {buffer}.read_varint64()"
        elif isinstance(serializer, Float32Serializer):
            return f"{field_value} = {buffer}.read_float32()"
        elif isinstance(serializer, Float64Serializer):
            return f"{field_value} = {buffer}.read_float64()"
        elif isinstance(serializer, StringSerializer):
            return f"{field_value} = {buffer}.read_string()"
        else:
            return None  # Complex type, needs ref handling

    def _write_non_nullable_field(self, buffer, field_value, serializer):
        """Write a non-nullable field: inline for basic types, else ``write_ref_pyobject``."""
        if isinstance(serializer, BooleanSerializer):
            buffer.write_bool(field_value)
        elif isinstance(serializer, ByteSerializer):
            buffer.write_int8(field_value)
        elif isinstance(serializer, Int16Serializer):
            buffer.write_int16(field_value)
        elif isinstance(serializer, Int32Serializer):
            buffer.write_varint32(field_value)
        elif isinstance(serializer, Int64Serializer):
            buffer.write_varint64(field_value)
        elif isinstance(serializer, Float32Serializer):
            buffer.write_float32(field_value)
        elif isinstance(serializer, Float64Serializer):
            buffer.write_float64(field_value)
        elif isinstance(serializer, StringSerializer):
            buffer.write_string(field_value)
        else:
            self.fory.write_ref_pyobject(buffer, field_value)

    def _read_non_nullable_field(self, buffer, serializer):
        """Read a field written by ``_write_non_nullable_field``."""
        if isinstance(serializer, BooleanSerializer):
            return buffer.read_bool()
        elif isinstance(serializer, ByteSerializer):
            return buffer.read_int8()
        elif isinstance(serializer, Int16Serializer):
            return buffer.read_int16()
        elif isinstance(serializer, Int32Serializer):
            return buffer.read_varint32()
        elif isinstance(serializer, Int64Serializer):
            return buffer.read_varint64()
        elif isinstance(serializer, Float32Serializer):
            return buffer.read_float32()
        elif isinstance(serializer, Float64Serializer):
            return buffer.read_float64()
        elif isinstance(serializer, StringSerializer):
            return buffer.read_string()
        else:
            return self.fory.read_ref_pyobject(buffer)

    def _write_nullable_field(self, buffer, field_value, serializer):
        """Write a nullable field as a null/not-null flag byte plus the value."""
        if field_value is None:
            buffer.write_int8(NULL_FLAG)
        else:
            buffer.write_int8(NOT_NULL_VALUE_FLAG)
            if isinstance(serializer, StringSerializer):
                buffer.write_string(field_value)
            else:
                self.fory.write_ref_pyobject(buffer, field_value)

    def _read_nullable_field(self, buffer, serializer):
        """Read a field written by ``_write_nullable_field``."""
        flag = buffer.read_int8()
        if flag == NULL_FLAG:
            return None
        else:
            if isinstance(serializer, StringSerializer):
                return buffer.read_string()
            else:
                return self.fory.read_ref_pyobject(buffer)

    def _gen_write_method(self):
        """Generate and compile a specialized Python-native ``write(buffer, value)``."""
        context = {}
        buffer, fory, value, value_dict = "buffer", "fory", "value", "value_dict"
        context[fory] = self.fory
        context["_serializers"] = self._serializers

        stmts = [
            f'"""write method for {self.type_}"""',
        ]

        # Write hash only in non-compatible mode; in compatible mode, write field count
        if not self.fory.compatible:
            stmts.append(f"{buffer}.write_int32({self._hash})")
        else:
            stmts.append(f"{buffer}.write_varuint32({len(self._field_names)})")

        if not self._has_slots:
            stmts.append(f"{value_dict} = {value}.__dict__")

        # Write field values in order
        for index, field_name in enumerate(self._field_names):
            field_value = f"field_value{index}"
            serializer_var = f"serializer{index}"
            serializer = self._serializers[index]
            context[serializer_var] = serializer

            if not self._has_slots:
                stmts.append(f"{field_value} = {value_dict}['{field_name}']")
            else:
                stmts.append(f"{field_value} = {value}.{field_name}")

            is_nullable = self._nullable_fields.get(field_name, False)
            if is_nullable:
                # Use gen_write_nullable_basic_stmts for nullable basic types
                if isinstance(serializer, BooleanSerializer):
                    stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, bool))
                elif isinstance(serializer, (ByteSerializer, Int16Serializer, Int32Serializer, Int64Serializer)):
                    stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, int))
                elif isinstance(serializer, (Float32Serializer, Float64Serializer)):
                    stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, float))
                elif isinstance(serializer, StringSerializer):
                    stmts.extend(gen_write_nullable_basic_stmts(buffer, field_value, str))
                else:
                    # For complex types, use write_ref_pyobject (no extra null flag)
                    stmts.append(f"{fory}.write_ref_pyobject({buffer}, {field_value})")
            else:
                stmt = self._get_write_stmt_for_codegen(serializer, buffer, field_value)
                if stmt is None:
                    stmt = f"{fory}.write_ref_pyobject({buffer}, {field_value})"
                stmts.append(stmt)

        self._write_method_code, func = compile_function(
            f"write_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
            [buffer, value],
            stmts,
            context,
        )
        return func

    def _gen_read_method(self):
        """Generate and compile a specialized Python-native ``read(buffer)``."""
        from pyfory.codegen import ident_lines

        context = dict(_jit_context)
        buffer, fory, obj_class, obj, obj_dict = (
            "buffer",
            "fory",
            "obj_class",
            "obj",
            "obj_dict",
        )
        ref_resolver = "ref_resolver"
        context[fory] = self.fory
        context[obj_class] = self.type_
        context[ref_resolver] = self.fory.ref_resolver
        context["_serializers"] = self._serializers
        current_class_field_names = set(self._get_field_names(self.type_))

        stmts = [
            f'"""read method for {self.type_}"""',
        ]
        if not self.fory.strict:
            context["checker"] = self.fory.policy
            stmts.append(f"checker.authorize_instantiation({obj_class})")

        # Read hash only in non-compatible mode; in compatible mode, read field count
        if not self.fory.compatible:
            stmts.extend(
                [
                    f"read_hash = {buffer}.read_int32()",
                    f"if read_hash != {self._hash}:",
                    f"""   raise TypeNotCompatibleError(
    f"Hash {{read_hash}} is not consistent with {self._hash} for type {self.type_}")""",
                ]
            )
        else:
            stmts.append(f"num_fields_written = {buffer}.read_varuint32()")

        stmts.extend(
            [
                f"{obj} = {obj_class}.__new__({obj_class})",
                f"{ref_resolver}.reference({obj})",
            ]
        )

        if not self._has_slots:
            stmts.append(f"{obj_dict} = {obj}.__dict__")

        # Read field values in order
        for index, field_name in enumerate(self._field_names):
            serializer_var = f"serializer{index}"
            serializer = self._serializers[index]
            context[serializer_var] = serializer
            field_value = f"field_value{index}"
            is_nullable = self._nullable_fields.get(field_name, False)

            # Build field reading statements
            field_stmts = []

            if is_nullable:
                # Use gen_read_nullable_basic_stmts for nullable basic types
                if isinstance(serializer, BooleanSerializer):
                    field_stmts.extend(gen_read_nullable_basic_stmts(buffer, bool, lambda v: f"{field_value} = {v}"))
                elif isinstance(serializer, (ByteSerializer, Int16Serializer, Int32Serializer, Int64Serializer)):
                    field_stmts.extend(gen_read_nullable_basic_stmts(buffer, int, lambda v: f"{field_value} = {v}"))
                elif isinstance(serializer, (Float32Serializer, Float64Serializer)):
                    field_stmts.extend(gen_read_nullable_basic_stmts(buffer, float, lambda v: f"{field_value} = {v}"))
                elif isinstance(serializer, StringSerializer):
                    field_stmts.extend(gen_read_nullable_basic_stmts(buffer, str, lambda v: f"{field_value} = {v}"))
                else:
                    # For complex types, use read_ref_pyobject
                    field_stmts.append(f"{field_value} = {fory}.read_ref_pyobject({buffer})")
            else:
                stmt = self._get_read_stmt_for_codegen(serializer, buffer, field_value)
                if stmt is None:
                    stmt = f"{field_value} = {fory}.read_ref_pyobject({buffer})"
                field_stmts.append(stmt)

            # Set field value if it exists in current class
            if field_name not in current_class_field_names:
                field_stmts.append(f"# {field_name} is not in {self.type_}")
            else:
                if not self._has_slots:
                    field_stmts.append(f"{obj_dict}['{field_name}'] = {field_value}")
                else:
                    field_stmts.append(f"{obj}.{field_name} = {field_value}")

            # In compatible mode only read fields the writer actually sent.
            if self.fory.compatible:
                stmts.append(f"if {index} < num_fields_written:")
                stmts.extend(ident_lines(field_stmts))
            else:
                stmts.extend(field_stmts)

        stmts.append(f"return {obj}")
        self._read_method_code, func = compile_function(
            f"read_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
            [buffer],
            stmts,
            context,
        )
        return func

    def _gen_xwrite_method(self):
        """Generate and compile a specialized cross-language ``xwrite(buffer, value)``."""
        context = {}
        buffer, fory, value, value_dict = "buffer", "fory", "value", "value_dict"
        context[fory] = self.fory
        context["_serializers"] = self._serializers
        stmts = [
            f'"""xwrite method for {self.type_}"""',
        ]
        if not self.fory.compatible:
            stmts.append(f"{buffer}.write_int32({self._hash})")
        if not self._has_slots:
            stmts.append(f"{value_dict} = {value}.__dict__")
        for index, field_name in enumerate(self._field_names):
            field_value = f"field_value{index}"
            serializer_var = f"serializer{index}"
            serializer = self._serializers[index]
            context[serializer_var] = serializer
            if not self._has_slots:
                stmts.append(f"{field_value} = {value_dict}['{field_name}']")
            else:
                stmts.append(f"{field_value} = {value}.{field_name}")
            is_nullable = self._nullable_fields.get(field_name, False)
            if is_nullable:
                if isinstance(serializer, StringSerializer):
                    stmts.extend(
                        [
                            f"if {field_value} is None:",
                            f"    {buffer}.write_int8({NULL_FLAG})",
                            "else:",
                            f"    {buffer}.write_int8({NOT_NULL_VALUE_FLAG})",
                            f"    {buffer}.write_string({field_value})",
                        ]
                    )
                else:
                    stmts.append(f"{fory}.xwrite_ref({buffer}, {field_value}, serializer={serializer_var})")
            else:
                stmt = self._get_write_stmt_for_codegen(serializer, buffer, field_value)
                if stmt is None:
                    stmt = f"{fory}.xwrite_no_ref({buffer}, {field_value}, serializer={serializer_var})"
                stmts.append(stmt)
        self._xwrite_method_code, func = compile_function(
            f"xwrite_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
            [buffer, value],
            stmts,
            context,
        )
        return func

    def _gen_xread_method(self):
        """Generate and compile a specialized cross-language ``xread(buffer)``."""
        context = dict(_jit_context)
        buffer, fory, obj_class, obj, obj_dict = (
            "buffer",
            "fory",
            "obj_class",
            "obj",
            "obj_dict",
        )
        ref_resolver = "ref_resolver"
        context[fory] = self.fory
        context[obj_class] = self.type_
        context[ref_resolver] = self.fory.ref_resolver
        context["_serializers"] = self._serializers
        current_class_field_names = set(self._get_field_names(self.type_))
        stmts = [
            f'"""xread method for {self.type_}"""',
        ]
        if not self.fory.strict:
            context["checker"] = self.fory.policy
            stmts.append(f"checker.authorize_instantiation({obj_class})")
        if not self.fory.compatible:
            stmts.extend(
                [
                    f"read_hash = {buffer}.read_int32()",
                    f"if read_hash != {self._hash}:",
                    f"""   raise TypeNotCompatibleError(
    f"Hash {{read_hash}} is not consistent with {self._hash} for type {self.type_}")""",
                ]
            )
        stmts.extend(
            [
                f"{obj} = {obj_class}.__new__({obj_class})",
                f"{ref_resolver}.reference({obj})",
            ]
        )

        if not self._has_slots:
            stmts.append(f"{obj_dict} = {obj}.__dict__")

        for index, field_name in enumerate(self._field_names):
            serializer_var = f"serializer{index}"
            serializer = self._serializers[index]
            context[serializer_var] = serializer
            field_value = f"field_value{index}"
            is_nullable = self._nullable_fields.get(field_name, False)
            if is_nullable:
                if isinstance(serializer, StringSerializer):
                    stmts.extend(
                        [
                            f"if {buffer}.read_int8() >= {NOT_NULL_VALUE_FLAG}:",
                            f"    {field_value} = {buffer}.read_string()",
                            "else:",
                            f"    {field_value} = None",
                        ]
                    )
                else:
                    stmts.append(f"{field_value} = {fory}.xread_ref({buffer}, serializer={serializer_var})")
            else:
                stmt = self._get_read_stmt_for_codegen(serializer, buffer, field_value)
                if stmt is None:
                    stmt = f"{field_value} = {fory}.xread_no_ref({buffer}, serializer={serializer_var})"
                stmts.append(stmt)
            if field_name not in current_class_field_names:
                stmts.append(f"# {field_name} is not in {self.type_}")
                continue
            if not self._has_slots:
                stmts.append(f"{obj_dict}['{field_name}'] = {field_value}")
            else:
                stmts.append(f"{obj}.{field_name} = {field_value}")
        stmts.append(f"return {obj}")
        self._xread_method_code, func = compile_function(
            f"xread_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
            [buffer],
            stmts,
            context,
        )
        return func

    def write(self, buffer, value):
        """Write a dataclass instance in the Python-native format (interpreted path).

        Non-basic fields use ``write_ref_pyobject`` whether nullable or not,
        matching ``_gen_write_method``.
        """
        # NOTE(review): nullable basic fields here are encoded as flag +
        # write_string / write_ref_pyobject, while the generated path uses
        # gen_write_nullable_basic_stmts; confirm the two encodings agree.
        self._write_header(buffer)
        basic = self._BASIC_FIELD_SERIALIZERS
        for field_name, serializer in zip(self._field_names, self._serializers):
            field_value = getattr(value, field_name)
            if self._nullable_fields.get(field_name, False) and isinstance(serializer, basic):
                self._write_nullable_field(buffer, field_value, serializer)
            else:
                self._write_non_nullable_field(buffer, field_value, serializer)

    def read(self, buffer):
        """Read a dataclass instance in the Python-native format (interpreted path).

        Like ``_gen_read_method``: checks the instantiation policy in
        non-strict mode, validates the header, creates the object via
        ``__new__`` (no ``__init__``) and registers it with the ref resolver
        before reading fields.

        Raises:
            TypeNotCompatibleError: On struct hash mismatch (non-compatible mode).
        """
        fory = self.fory
        if not fory.strict:
            fory.policy.authorize_instantiation(self.type_)
        num_fields_written = self._read_header(buffer)

        obj = self.type_.__new__(self.type_)
        fory.ref_resolver.reference(obj)
        current_class_field_names = set(self._get_field_names(self.type_))
        basic = self._BASIC_FIELD_SERIALIZERS

        # NOTE(review): in compatible mode, if the writer sent more fields than
        # this class knows, the surplus is not consumed here.
        for index, field_name in enumerate(self._field_names):
            # Only read fields that were actually written.
            if index >= num_fields_written:
                break

            serializer = self._serializers[index]
            if self._nullable_fields.get(field_name, False) and isinstance(serializer, basic):
                field_value = self._read_nullable_field(buffer, serializer)
            else:
                field_value = self._read_non_nullable_field(buffer, serializer)

            if field_name in current_class_field_names:
                setattr(obj, field_name, field_value)
        return obj

    def xwrite(self, buffer: Buffer, value):
        """Write a dataclass instance in the cross-language format (interpreted path).

        Emits the same bytes as the method generated by ``_gen_xwrite_method``:
        nullable strings as flag + string, other nullable fields via
        ``xwrite_ref``, non-nullable basic fields inline, and other
        non-nullable fields via ``xwrite_no_ref``.

        Raises:
            TypeError: If this serializer was not created in xlang mode.
        """
        if not self._xlang:
            raise TypeError("xwrite can only be called when DataClassSerializer is in xlang mode")
        fory = self.fory
        if not fory.compatible:
            buffer.write_int32(self._hash)
        for field_name, serializer in zip(self._field_names, self._serializers):
            field_value = getattr(value, field_name)
            if self._nullable_fields.get(field_name, False):
                if isinstance(serializer, StringSerializer):
                    if field_value is None:
                        buffer.write_int8(NULL_FLAG)
                    else:
                        buffer.write_int8(NOT_NULL_VALUE_FLAG)
                        buffer.write_string(field_value)
                else:
                    fory.xwrite_ref(buffer, field_value, serializer=serializer)
            elif isinstance(serializer, self._BASIC_FIELD_SERIALIZERS):
                self._write_non_nullable_field(buffer, field_value, serializer)
            else:
                fory.xwrite_no_ref(buffer, field_value, serializer=serializer)

    def xread(self, buffer):
        """Read a dataclass instance in the cross-language format (interpreted path).

        Mirrors the method generated by ``_gen_xread_method``, including the
        instantiation-policy check in non-strict mode.

        Raises:
            TypeError: If this serializer was not created in xlang mode.
            TypeNotCompatibleError: On struct hash mismatch (non-compatible mode).
        """
        if not self._xlang:
            raise TypeError("xread can only be called when DataClassSerializer is in xlang mode")
        fory = self.fory
        if not fory.strict:
            fory.policy.authorize_instantiation(self.type_)
        if not fory.compatible:
            hash_ = buffer.read_int32()
            if hash_ != self._hash:
                raise TypeNotCompatibleError(
                    f"Hash {hash_} is not consistent with {self._hash} for type {self.type_}",
                )
        obj = self.type_.__new__(self.type_)
        fory.ref_resolver.reference(obj)
        current_class_field_names = set(self._get_field_names(self.type_))
        for field_name, serializer in zip(self._field_names, self._serializers):
            if self._nullable_fields.get(field_name, False):
                if isinstance(serializer, StringSerializer):
                    # Any flag >= NOT_NULL_VALUE_FLAG means a value follows.
                    if buffer.read_int8() >= NOT_NULL_VALUE_FLAG:
                        field_value = buffer.read_string()
                    else:
                        field_value = None
                else:
                    field_value = fory.xread_ref(buffer, serializer=serializer)
            elif isinstance(serializer, self._BASIC_FIELD_SERIALIZERS):
                field_value = self._read_non_nullable_field(buffer, serializer)
            else:
                field_value = fory.xread_no_ref(buffer, serializer=serializer)
            if field_name in current_class_field_names:
                setattr(obj, field_name, field_value)
        return obj
| |
| |
class DataClassStubSerializer(DataClassSerializer):
    """Lazy placeholder that builds the real ``DataClassSerializer`` on first use.

    On first use the real serializer is created, installed on the type's
    ``typeinfo`` and cached on the stub. Callers that still hold the stub
    (e.g. a parent struct's field serializer list) therefore do not rebuild
    and recompile a serializer on every call.
    """

    def __init__(self, fory, clz: type, xlang: bool = False):
        Serializer.__init__(self, fory, clz)
        self.xlang = xlang
        # Real serializer, created lazily by _replace.
        self._replacement = None

    def write(self, buffer, value):
        self._replace().write(buffer, value)

    def read(self, buffer):
        return self._replace().read(buffer)

    def xwrite(self, buffer, value):
        self._replace().xwrite(buffer, value)

    def xread(self, buffer):
        return self._replace().xread(buffer)

    def _replace(self):
        """Return the real serializer, creating and registering it on first call."""
        replacement = self._replacement
        if replacement is None:
            typeinfo = self.fory.type_resolver.get_typeinfo(self.type_)
            current = typeinfo.serializer
            # Reuse a real serializer already installed (e.g. by another stub
            # for the same type) instead of generating a second one.
            if (
                isinstance(current, DataClassSerializer)
                and not isinstance(current, DataClassStubSerializer)
                and getattr(current, "_xlang", None) == self.xlang
            ):
                replacement = current
            else:
                replacement = DataClassSerializer(self.fory, self.type_, self.xlang)
                typeinfo.serializer = replacement
            self._replacement = replacement
        return replacement
| |
| |
# Python ``array`` typecode -> (itemsize in bytes, fory array type, xlang TypeId).
# Byte arrays are left to the bytes serializer. On Windows the "l" typecode is
# 4 bytes, so int32/int64 use "l"/"q" there instead of "i"/"l".
typecode_dict = {
    "h": (2, int16_array, TypeId.INT16_ARRAY),
    ("l" if _WINDOWS else "i"): (4, int32_array, TypeId.INT32_ARRAY),
    ("q" if _WINDOWS else "l"): (8, int64_array, TypeId.INT64_ARRAY),
    "f": (4, float32_array, TypeId.FLOAT32_ARRAY),
    "d": (8, float64_array, TypeId.FLOAT64_ARRAY),
}

# Inverse table: xlang TypeId -> Python ``array`` typecode for this platform.
typeid_code = {type_id: code for code, (_, _, type_id) in typecode_dict.items()}
| |
| |
class PyArraySerializer(XlangCompatibleSerializer):
    """Serializer for ``array.array`` values with an xlang-mappable typecode.

    ``xwrite``/``xread`` use the typecode bound to ``type_id`` at construction
    (raw bytes prefixed by their length). ``write``/``read`` (Python-native)
    store the typecode in the payload, so any typecode round-trips.
    """

    typecode_dict = typecode_dict
    # Typecode -> fory array type, derived from the module-level table so the
    # platform-specific ("l" vs "q") choices live in one place.
    typecodearray_type = {code: array_type for code, (_, array_type, _) in typecode_dict.items()}

    def __init__(self, fory, ftype, type_id: int):
        super().__init__(fory, ftype)
        self.typecode = typeid_code[type_id]
        self.itemsize, _, self.type_id = typecode_dict[self.typecode]

    def xwrite(self, buffer, value):
        """Write ``value`` as ``varuint32(nbytes)`` followed by its raw buffer.

        Raises:
            ValueError: If ``value`` does not match this serializer's typecode
                or item size, or is not C-contiguous.
        """
        # Explicit checks instead of `assert`: asserts are stripped under
        # `python -O`, which would let a mismatched array be written silently.
        if value.itemsize != self.itemsize:
            raise ValueError(f"Expected array itemsize {self.itemsize}, got {value.itemsize}")
        view = memoryview(value)
        if view.format != self.typecode or view.itemsize != self.itemsize:
            raise ValueError(f"Expected array typecode {self.typecode!r}, got format {view.format!r}")
        if not view.c_contiguous:  # TODO handle non-contiguous buffers
            raise ValueError("Array buffer must be C-contiguous")
        nbytes = len(value) * self.itemsize
        buffer.write_varuint32(nbytes)
        buffer.write_buffer(value)

    def xread(self, buffer):
        data = buffer.read_bytes_and_size()
        arr = array.array(self.typecode, [])
        arr.frombytes(data)
        return arr

    def write(self, buffer, value: array.array):
        nbytes = len(value) * value.itemsize
        buffer.write_string(value.typecode)
        buffer.write_varuint32(nbytes)
        buffer.write_buffer(value)

    def read(self, buffer):
        typecode = buffer.read_string()
        data = buffer.read_bytes_and_size()
        arr = array.array(typecode[0], [])  # Take first character
        arr.frombytes(data)
        return arr
| |
| |
class DynamicPyArraySerializer(Serializer):
    """Serializer for dynamic Python arrays that handles any typecode.

    Cross-language format: ``varuint32(type_id)``, ``varuint32(nbytes)``, raw
    bytes. The Python-native format delegates to ``ReduceSerializer``.
    """

    def __init__(self, fory, cls):
        super().__init__(fory, cls)
        self._serializer = ReduceSerializer(fory, cls)

    def xwrite(self, buffer, value):
        itemsize, _, type_id = typecode_dict[value.typecode]
        view = memoryview(value)
        nbytes = len(value) * itemsize
        buffer.write_varuint32(type_id)
        buffer.write_varuint32(nbytes)
        if not view.c_contiguous:
            buffer.write_bytes(value.tobytes())
        else:
            buffer.write_buffer(value)

    def xread(self, buffer):
        # Must pair with write_varuint32 in xwrite; read_varint32 is the
        # signed counterpart of write_varint32 and would mis-decode the id.
        type_id = buffer.read_varuint32()
        typecode = typeid_code[type_id]
        data = buffer.read_bytes_and_size()
        arr = array.array(typecode, [])
        arr.frombytes(data)
        return arr

    def write(self, buffer, value):
        self._serializer.write(buffer, value)

    def read(self, buffer):
        return self._serializer.read(buffer)
| |
| |
if np:
    # use bytes serializer for byte array.
    # dtype -> (itemsize, buffer-protocol format char, fory ndarray type, TypeId).
    # C ``long`` is 32 bits wide on Windows, so int32/int64 buffers report
    # "l"/"q" there and "i"/"l" on other platforms.
    _np_dtypes_dict = {
        np.dtype(np.bool_): (1, "?", BoolNDArrayType, TypeId.BOOL_ARRAY),
        np.dtype(np.int16): (2, "h", Int16NDArrayType, TypeId.INT16_ARRAY),
        np.dtype(np.int32): (
            4,
            "l" if _WINDOWS else "i",
            Int32NDArrayType,
            TypeId.INT32_ARRAY,
        ),
        np.dtype(np.int64): (
            8,
            "q" if _WINDOWS else "l",
            Int64NDArrayType,
            TypeId.INT64_ARRAY,
        ),
        np.dtype(np.float32): (4, "f", Float32NDArrayType, TypeId.FLOAT32_ARRAY),
        np.dtype(np.float64): (8, "d", Float64NDArrayType, TypeId.FLOAT64_ARRAY),
    }
else:
    _np_dtypes_dict = {}
| |
| |
class Numpy1DArraySerializer(Serializer):
    """Serializer for one-dimensional numpy arrays of a single fixed dtype.

    Cross-language mode writes the payload size in bytes followed by the raw
    element bytes. Python-native mode delegates to ``ReduceSerializer``.
    """

    dtypes_dict = _np_dtypes_dict

    def __init__(self, fory, ftype, dtype):
        super().__init__(fory, ftype)
        self.dtype = dtype
        # Entry layout: (itemsize, buffer format char, fory ndarray type, TypeId).
        # So ``self.format`` holds the format character, while
        # ``self.typecode`` holds the fory ndarray type (e.g. Int32NDArrayType).
        self.itemsize, self.format, self.typecode, self.type_id = _np_dtypes_dict[self.dtype]
        self._serializer = ReduceSerializer(fory, np.ndarray)

    def xwrite(self, buffer, value):
        """Write ``value``'s byte length (varuint32) and its raw element bytes."""
        assert value.itemsize == self.itemsize
        view = memoryview(value)
        # memoryview.format is a format character ("i", "d", ...), so compare
        # it with ``self.format`` rather than with the ndarray type in
        # ``self.typecode``.
        assert view.format == self.format, (view.format, self.format)
        assert view.itemsize == self.itemsize
        nbytes = len(value) * self.itemsize
        buffer.write_varuint32(nbytes)
        # bool arrays and non-contiguous views are written via a contiguous
        # ``tobytes()`` copy.
        if self.dtype == np.dtype("bool") or not view.c_contiguous:
            buffer.write_bytes(value.tobytes())
        else:
            buffer.write_buffer(value)

    def xread(self, buffer):
        """Read the size-prefixed payload written by ``xwrite`` as a 1-D array."""
        data = buffer.read_bytes_and_size()
        return np.frombuffer(data, dtype=self.dtype)

    def write(self, buffer, value):
        self._serializer.write(buffer, value)

    def read(self, buffer):
        return self._serializer.read(buffer)
| |
| |
class NDArraySerializer(Serializer):
    """Serializer for arbitrary numpy arrays.

    Python-native mode writes the dtype (``write_ref``), the number of
    dimensions and each dimension (varuint32), then either every item along
    the first axis (object dtype) or the raw data as a buffer object.
    Cross-language mode writes a type id, the byte count and the raw bytes;
    reading it back is not implemented.
    """

    def xwrite(self, buffer, value):
        """Write the type id, total byte count and raw bytes of ``value``."""
        _, _, _, type_id = _np_dtypes_dict[value.dtype]
        view = memoryview(value)
        # ``len(value)`` only counts the first axis; ``nbytes`` covers every
        # element of an N-D array and equals ``len(value) * itemsize`` in 1-D.
        nbytes = value.nbytes
        buffer.write_varuint32(type_id)
        buffer.write_varuint32(nbytes)
        if value.dtype == np.dtype("bool") or not view.c_contiguous:
            buffer.write_bytes(value.tobytes())
        else:
            buffer.write_buffer(value)

    def xread(self, buffer):
        raise NotImplementedError("Multi-dimensional array not supported currently")

    def write(self, buffer, value):
        fory = self.fory
        dtype = value.dtype
        fory.write_ref(buffer, dtype)
        buffer.write_varuint32(len(value.shape))
        for dim in value.shape:
            buffer.write_varuint32(dim)
        if dtype.kind == "O":
            # Object arrays: each item along the first axis goes through Fory.
            buffer.write_varint32(len(value))
            for item in value:
                fory.write_ref(buffer, item)
        else:
            fory.write_buffer_object(buffer, NDArrayBufferObject(value))

    def read(self, buffer):
        fory = self.fory
        dtype = fory.read_ref(buffer)
        ndim = buffer.read_varuint32()
        shape = tuple(buffer.read_varuint32() for _ in range(ndim))
        if dtype.kind == "O":
            length = buffer.read_varint32()
            items = [fory.read_ref(buffer) for _ in range(length)]
            # ``np.array(items, dtype=object)`` would re-infer the shape from
            # the items (a 1-D array of lists would come back 2-D), so fill an
            # array of the recorded shape instead. Each item is one element
            # (1-D) or one sub-array along the first axis (N-D).
            result = np.empty(shape, dtype=object)
            for index, item in enumerate(items):
                result[index] = item
            return result
        fory_buf = fory.read_buffer_object(buffer)
        if isinstance(fory_buf, (memoryview, bytes)):
            return np.frombuffer(fory_buf, dtype=dtype).reshape(shape)
        return np.frombuffer(fory_buf.to_pybytes(), dtype=dtype).reshape(shape)
| |
| |
class BytesSerializer(XlangCompatibleSerializer):
    """Serializer for ``bytes``, routed through Fory's buffer-object channel."""

    def write(self, buffer, value):
        wrapped = BytesBufferObject(value)
        self.fory.write_buffer_object(buffer, wrapped)

    def read(self, buffer):
        payload = self.fory.read_buffer_object(buffer)
        if isinstance(payload, bytes):
            return payload
        if isinstance(payload, memoryview):
            return bytes(payload)
        # Anything else is expected to expose ``to_pybytes()``.
        return payload.to_pybytes()
| |
| |
class BytesBufferObject(BufferObject):
    """``BufferObject`` adapter around an immutable ``bytes`` payload."""

    __slots__ = ("binary",)

    def __init__(self, binary: bytes):
        self.binary = binary

    def total_bytes(self) -> int:
        """Return the payload length in bytes."""
        return len(self.binary)

    def write_to(self, stream):
        """Write the payload via ``write_bytes`` when available, else ``write``."""
        write = stream.write_bytes if hasattr(stream, "write_bytes") else stream.write
        write(self.binary)

    def getbuffer(self) -> memoryview:
        """Return a zero-copy view of the payload."""
        view = memoryview(self.binary)
        return view
| |
| |
class PickleBufferSerializer(XlangCompatibleSerializer):
    """Serializer for ``pickle.PickleBuffer`` (pickle protocol 5 buffers)."""

    def write(self, buffer, value):
        wrapped = PickleBufferObject(value)
        self.fory.write_buffer_object(buffer, wrapped)

    def read(self, buffer):
        payload = self.fory.read_buffer_object(buffer)
        if not isinstance(payload, (bytes, memoryview, bytearray, Buffer)):
            # Anything else is expected to expose ``to_pybytes()``.
            payload = payload.to_pybytes()
        return pickle.PickleBuffer(payload)
| |
| |
class PickleBufferObject(BufferObject):
    """``BufferObject`` adapter around a ``pickle.PickleBuffer``."""

    __slots__ = ("pickle_buffer",)

    def __init__(self, pickle_buffer):
        self.pickle_buffer = pickle_buffer

    def total_bytes(self) -> int:
        """Return the length of the wrapped buffer's ``raw()`` view."""
        raw_view = self.pickle_buffer.raw()
        return len(raw_view)

    def write_to(self, stream):
        """Copy the raw buffer into ``stream`` (``write_buffer`` if available)."""
        raw_view = self.pickle_buffer.raw()
        if not hasattr(stream, "write_buffer"):
            # Plain file-like streams receive a bytes copy of a memoryview.
            stream.write(bytes(raw_view) if isinstance(raw_view, memoryview) else raw_view)
            return
        stream.write_buffer(raw_view)

    def getbuffer(self) -> memoryview:
        """Return a memoryview over the raw buffer contents."""
        raw_view = self.pickle_buffer.raw()
        return raw_view if isinstance(raw_view, memoryview) else memoryview(bytes(raw_view))
| |
| |
class NDArrayBufferObject(BufferObject):
    """``BufferObject`` adapter exposing a numpy array's data bytes."""

    __slots__ = ("array", "dtype", "shape")

    def __init__(self, array):
        self.array = array
        self.dtype = array.dtype
        self.shape = array.shape

    def total_bytes(self) -> int:
        """Return the size of the array data in bytes."""
        return self.array.nbytes

    def write_to(self, stream):
        """Write a C-ordered bytes copy of the array to ``stream``."""
        data = self.array.tobytes()
        if hasattr(stream, "write_buffer"):
            stream.write_buffer(data)
        else:
            stream.write(data)

    def getbuffer(self) -> memoryview:
        """Return a flat unsigned-byte memoryview of the array data.

        Both branches produce a 1-D ``"B"`` view whose ``len()`` equals
        ``total_bytes()``.
        """
        arr = self.array
        if arr.flags.c_contiguous and arr.size:
            # ``arr.data`` keeps the array's item format and shape; cast it to a
            # flat byte view (zero-copy). Zero-size arrays take the fallback
            # because memoryview.cast rejects N-D views with a zero in the shape.
            return memoryview(arr.data).cast("B")
        return memoryview(arr.tobytes())
| |
| |
class StatefulSerializer(XlangCompatibleSerializer):
    """
    Serializer for objects that support __getstate__ and __setstate__.
    Uses Fory's native serialization for better cross-language support.

    Wire layout: constructor ``args``, constructor ``kwargs`` and the state
    returned by ``__getstate__``, each written with ``write_ref``.
    """

    def __init__(self, fory, cls):
        super().__init__(fory, cls)
        self.cls = cls
        # Looked up once per class instead of on every write.
        self._getnewargs_ex = getattr(cls, "__getnewargs_ex__", None)
        self._getnewargs = getattr(cls, "__getnewargs__", None)

    def write(self, buffer, value):
        state = value.__getstate__()
        if self._getnewargs_ex is not None:
            args, kwargs = self._getnewargs_ex(value)
        elif self._getnewargs is not None:
            args, kwargs = self._getnewargs(value), {}
        else:
            args, kwargs = (), {}
        write_ref = self.fory.write_ref
        # Constructor arguments first, then the state.
        for part in (args, kwargs, state):
            write_ref(buffer, part)

    def read(self, buffer):
        fory = self.fory
        args, kwargs, state = [fory.read_ref(buffer) for _ in range(3)]
        cls = self.cls
        if not (args or kwargs):
            # No constructor arguments captured: allocate without __init__.
            obj = cls.__new__(cls)
        else:
            # NOTE(review): pickle hands __getnewargs__ results to ``cls.__new__``
            # and skips ``__init__``; calling the class here runs ``__init__``
            # with those arguments. Kept as-is for compatibility.
            obj = cls(*args, **kwargs)
        if state:
            fory.policy.intercept_setstate(obj, state)
            obj.__setstate__(state)
        return obj
| |
| |
class ReduceSerializer(XlangCompatibleSerializer):
    """
    Serializer for objects that support __reduce__ or __reduce_ex__.
    Uses Fory's native serialization for better cross-language support.
    Has higher precedence than StatefulSerializer.

    Wire format: a varuint32 item count followed by that many values written
    with ``write_ref``. The first value is a format flag:

    * ``0`` -- ``(0, global_name)``: the reduce result was a string.
    * ``1`` -- ``(1, callable, args[, state[, listitems[, dictitems]]])``.
    """

    # The flag plus the longest supported reduce tuple (5 items).
    _MAX_REDUCE_ITEMS = 6

    def __init__(self, fory, cls):
        super().__init__(fory, cls)
        self.cls = cls
        # Cache the method references as fields in the serializer.
        self._reduce_ex = getattr(cls, "__reduce_ex__", None)
        self._reduce = getattr(cls, "__reduce__", None)
        self._getnewargs_ex = getattr(cls, "__getnewargs_ex__", None)
        self._getnewargs = getattr(cls, "__getnewargs__", None)

    @staticmethod
    def _compute_reduce_result(value):
        """Return ``value``'s reduce result, preferring a custom ``__reduce_ex__``.

        :raises ValueError: if ``value`` has neither reduce method.
        """
        # Use __reduce_ex__ only when the class overrides object's default;
        # protocol 5 enables pickle5 out-of-band buffer support.
        if hasattr(value, "__reduce_ex__") and value.__class__.__reduce_ex__ is not object.__reduce_ex__:
            try:
                return value.__reduce_ex__(5)
            except TypeError:
                # Some objects don't support protocol argument
                return value.__reduce_ex__()
        if hasattr(value, "__reduce__"):
            return value.__reduce__()
        raise ValueError(f"Object {value} has no __reduce__ or __reduce_ex__ method")

    def write(self, buffer, value):
        reduce_result = self._compute_reduce_result(value)
        if isinstance(reduce_result, str):
            # Case 1: Just a global name (simple case)
            reduce_data = (0, reduce_result)
        elif isinstance(reduce_result, tuple):
            if not 2 <= len(reduce_result) <= 5:
                raise ValueError(f"Invalid __reduce__ result length: {len(reduce_result)}")
            # Cases 2-5: (callable, args[, state[, listitems[, dictitems]]])
            items = list(reduce_result)
            # The pickle protocol defines listitems (index 3) and dictitems
            # (index 4) as *iterators*, typically tied to the object being
            # reduced; materialize them so plain lists get written instead.
            for index in (3, 4):
                if index < len(items) and items[index] is not None:
                    items[index] = list(items[index])
            reduce_data = (1, *items)
        else:
            raise ValueError(f"Invalid __reduce__ result type: {type(reduce_result)}")
        buffer.write_varuint32(len(reduce_data))
        fory = self.fory
        for item in reduce_data:
            fory.write_ref(buffer, item)

    def read(self, buffer):
        num_items = buffer.read_varuint32()
        if num_items > self._MAX_REDUCE_ITEMS:
            # The count comes from the wire: validate explicitly rather than
            # with ``assert``, which is stripped under ``python -O``.
            raise ValueError(f"Invalid reduce data item count: {num_items}")
        fory = self.fory
        reduce_data = [None] * self._MAX_REDUCE_ITEMS
        for i in range(num_items):
            reduce_data[i] = fory.read_ref(buffer)

        flag = reduce_data[0]
        if flag == 0:
            # Case 1: Global name
            return self._resolve_global(reduce_data[1])
        if flag != 1:
            raise ValueError(f"Invalid reduce data format flag: {flag}")

        # Cases 2-5: callable with args and optional state/items.
        _, callable_obj, args, state, listitems, dictitems = reduce_data
        args = args or ()
        obj = fory.policy.intercept_reduce_call(callable_obj, args)
        if obj is None:
            obj = callable_obj(*args)
        if state is not None:
            self._apply_state(obj, state)
        if listitems is not None:
            obj.extend(listitems)
        if dictitems is not None:
            for key, item in dictitems:
                obj[key] = item

        result = fory.policy.inspect_reduced_object(obj)
        return obj if result is None else result

    @staticmethod
    def _resolve_global(global_name):
        """Resolve ``"module.attr"`` (or a bare builtin name) to the named object.

        :raises ValueError: if a bare name is not a builtin.
        """
        if "." in global_name:
            module_name, obj_name = global_name.rsplit(".", 1)
            module = __import__(module_name, fromlist=[obj_name])
            return getattr(module, obj_name)
        # NOTE(review): pickle interprets a bare reduce string relative to the
        # object's module, but ``write`` does not record the module, so a bare
        # name can only be resolved against builtins here.
        try:
            return getattr(builtins, global_name)
        except AttributeError:
            raise ValueError(f"Cannot resolve global name: {global_name}") from None

    @staticmethod
    def _apply_state(obj, state):
        """Restore ``state`` on ``obj`` the way pickle's BUILD opcode does."""
        if hasattr(obj, "__setstate__"):
            obj.__setstate__(state)
            return
        # Objects with __slots__ reduce to a ``(dict_state, slot_state)`` pair;
        # either half may be None.
        slot_state = None
        if isinstance(state, tuple) and len(state) == 2:
            state, slot_state = state
        if state and hasattr(obj, "__dict__"):
            # Fallback: update __dict__ directly
            obj.__dict__.update(state)
        if slot_state:
            for name, value in slot_state.items():
                setattr(obj, name, value)
| |
| |
# Class ``__dict__`` entries that are not copied when serializing a local
# class: ``type()`` recreates the ``__dict__``/``__weakref__`` descriptors,
# and ``__module__``/``__qualname__`` are written and restored separately.
__skip_class_attr_names__ = (
    "__module__",
    "__qualname__",
    "__dict__",
    "__weakref__",
)
| |
| |
class TypeSerializer(Serializer):
    """Serializer for Python type objects (classes), including local classes.

    Importable classes are written as ``module`` + ``qualname`` and
    re-imported on read. Classes defined in ``__main__`` or inside a function
    (``<locals>`` in the qualname) are rebuilt from their bases, classmethods
    and class dict, which requires reference tracking.
    """

    def __init__(self, fory, cls):
        super().__init__(fory, cls)
        self.cls = cls

    def write(self, buffer, value):
        module_name = value.__module__
        qualname = value.__qualname__

        if module_name == "__main__" or "<locals>" in qualname:
            # Local class - serialize full context
            buffer.write_int8(1)  # Local class marker
            self._serialize_local_class(buffer, value)
        else:
            buffer.write_int8(0)  # Global class marker
            buffer.write_string(module_name)
            buffer.write_string(qualname)

    def read(self, buffer):
        class_type = buffer.read_int8()

        if class_type == 1:
            # Local class - deserialize from full context
            return self._deserialize_local_class(buffer)
        # Global class - import by module and walk the qualname
        module_name = buffer.read_string()
        qualname = buffer.read_string()
        cls = importlib.import_module(module_name)
        for name in qualname.split("."):
            cls = getattr(cls, name)
        result = self.fory.policy.validate_class(cls, is_local=False)
        if result is not None:
            cls = result
        return cls

    def _serialize_local_class(self, buffer, cls):
        """Serialize a local class by capturing its creation context."""
        assert self.fory.ref_tracking, "Reference tracking must be enabled for local classes serialization"
        # Basic class information
        module = cls.__module__
        qualname = cls.__qualname__
        buffer.write_string(module)
        buffer.write_string(qualname)
        fory = self.fory

        # Base classes go through Fory's normal serialization (including other
        # local classes).
        bases = cls.__bases__
        buffer.write_varuint32(len(bases))
        for base in bases:
            fory.write_ref(buffer, base)

        # Split the class dictionary (minus special attributes) into
        # classmethods and everything else.
        class_dict = {}
        class_methods = []  # (attr_name, classmethod) pairs
        for attr_name, attr_value in cls.__dict__.items():
            # Skip special attributes that are handled by type() constructor
            if attr_name in __skip_class_attr_names__:
                continue
            if isinstance(attr_value, classmethod):
                class_methods.append((attr_name, attr_value))
            else:
                class_dict[attr_name] = attr_value
        # classmethods are written as their underlying function to avoid
        # circular deps in method deserialization
        buffer.write_varuint32(len(class_methods))
        for attr_name, class_method in class_methods:
            buffer.write_string(attr_name)
            fory.write_ref(buffer, class_method.__func__)

        # The remaining class dict uses Fory's normal serialization; methods
        # with closures are handled by FunctionSerializer.
        fory.write_ref(buffer, class_dict)

    def _deserialize_local_class(self, buffer):
        """Deserialize a local class by recreating it with the captured context."""
        fory = self.fory
        assert fory.ref_tracking, "Reference tracking must be enabled for local classes deserialization"
        # Read basic class information
        module = buffer.read_string()
        qualname = buffer.read_string()
        name = qualname.rsplit(".", 1)[-1]
        ref_id = fory.ref_resolver.last_preserved_ref_id()

        # Read base classes
        num_bases = buffer.read_varuint32()
        bases = tuple([fory.read_ref(buffer) for _ in range(num_bases)])
        # Create the class using type() constructor
        cls = type(name, bases, {})
        # `class_dict` may reference to `cls`, which is a circular reference
        fory.ref_resolver.set_read_object(ref_id, cls)

        # classmethods
        for _ in range(buffer.read_varuint32()):
            attr_name = buffer.read_string()
            func = fory.read_ref(buffer)
            # Re-wrap as a real classmethod, as in the original class, so the
            # method binds to whichever class it is accessed through. A
            # ``types.MethodType(func, cls)`` would pin ``cls`` even for
            # subclasses and would not round-trip as a classmethod.
            setattr(cls, attr_name, classmethod(func))
        # Read class dictionary; FunctionSerializer rebuilds any methods.
        class_dict = fory.read_ref(buffer)
        for k, v in class_dict.items():
            setattr(cls, k, v)

        # Set module and qualname
        cls.__module__ = module
        cls.__qualname__ = qualname
        result = fory.policy.validate_class(cls, is_local=True)
        if result is not None:
            cls = result
        return cls
| |
| |
class ModuleSerializer(Serializer):
    """Serializer for python module objects, written by name and re-imported on read."""

    def __init__(self, fory):
        super().__init__(fory, types.ModuleType)

    def write(self, buffer, value):
        buffer.write_string(value.__name__)

    def read(self, buffer):
        module = importlib.import_module(buffer.read_string())
        # NOTE(review): the import (and its side effects) happens before
        # ``validate_module`` is consulted with the module's name.
        replacement = self.fory.policy.validate_module(module.__name__)
        return module if replacement is None else replacement
| |
| |
class MappingProxySerializer(Serializer):
    """Serializer for ``types.MappingProxyType``, written as a plain dict copy."""

    def __init__(self, fory):
        super().__init__(fory, types.MappingProxyType)

    def write(self, buffer, value):
        snapshot = dict(value)
        self.fory.write_ref(buffer, snapshot)

    def read(self, buffer):
        mapping = self.fory.read_ref(buffer)
        return types.MappingProxyType(mapping)
| |
| |
class FunctionSerializer(XlangCompatibleSerializer):
    """Serializer for function objects

    This serializer captures all the necessary information to recreate a function:
    - Function code
    - Function name
    - Module name
    - Closure variables
    - Global variables
    - Default arguments (positional and keyword-only)
    - Function attributes

    The code object is serialized with marshal, and all other components
    (defaults, globals, closure cells, attrs) go through Fory's own
    write_ref/read_ref pipeline to ensure proper type registration
    and reference tracking.

    Each value starts with an int8 kind marker:

    * ``0`` -- bound method: ``__self__`` (write_ref) then the method name.
    * ``1`` -- importable function: module name then qualname.
    * ``2`` -- local or ``__main__`` function: module name, qualname,
      marshalled code, defaults, closure contents, free-variable names,
      referenced globals and extra attributes.

    Only Python-native mode is supported; ``xwrite``/``xread`` raise.
    """

    # Cache for function attributes that are handled separately
    _FUNCTION_ATTRS = frozenset(
        (
            "__code__",
            "__name__",
            "__defaults__",
            "__closure__",
            "__globals__",
            "__module__",
            "__qualname__",
        )
    )

    @staticmethod
    def _referenced_global_names(code, globals_dict):
        """Return the names in ``globals_dict`` used by ``code`` or any code nested in it.

        Nested functions, lambdas and class bodies are separate code objects
        stored in ``co_consts``, so the walk descends through them at any
        depth. Names that also exist in ``builtins`` are skipped.
        """
        names = set()
        pending = [code]
        while pending:
            current = pending.pop()
            for name in current.co_names:
                if name in globals_dict and not hasattr(builtins, name):
                    names.add(name)
            for const in current.co_consts:
                if isinstance(const, types.CodeType):
                    pending.append(const)
        return names

    def _serialize_function(self, buffer, func):
        """Serialize a function by capturing all its components."""
        fory = self.fory
        instance = getattr(func, "__self__", None)
        if instance is not None and not inspect.ismodule(instance):
            # Bound method: written as (owner object, method name).
            func_name = func.__name__
            buffer.write_int8(0)
            fory.write_ref(buffer, instance)
            buffer.write_string(func_name)
            return

        # Regular function or lambda
        code = func.__code__
        module = func.__module__
        qualname = func.__qualname__

        if "<locals>" not in qualname and module != "__main__":
            buffer.write_int8(1)  # importable function: re-imported by name
            buffer.write_string(module)
            buffer.write_string(qualname)
            return

        buffer.write_int8(2)  # local / __main__ function: rebuilt from parts
        buffer.write_string(module)
        buffer.write_string(qualname)

        defaults = func.__defaults__
        closure = func.__closure__
        globals_dict = func.__globals__

        # marshal handles code objects natively. NOTE: marshal data is only
        # guaranteed to load on the same Python version that wrote it.
        buffer.write_bytes_and_size(marshal.dumps(code))

        # Defaults: presence flag, count, then each value.
        buffer.write_bool(defaults is not None)
        if defaults is not None:
            buffer.write_varuint32(len(defaults))
            for default_value in defaults:
                fory.write_ref(buffer, default_value)

        # Closure: presence flag, free-variable count, then each cell's contents.
        freevars = code.co_freevars or ()
        buffer.write_bool(closure is not None)
        buffer.write_varuint32(len(freevars))
        if closure:
            for cell in closure:
                fory.write_ref(buffer, cell.cell_contents)

        # Free-variable names as plain strings.
        buffer.write_varuint32(len(freevars))
        for name in freevars:
            buffer.write_string(name)

        # Only the globals actually referenced by the function (or its nested
        # code); sorted so the output does not depend on set ordering.
        global_names = self._referenced_global_names(code, globals_dict)
        fory.write_ref(buffer, {name: globals_dict[name] for name in sorted(global_names)})

        # Extra (non-dunder) attributes set on the function object.
        attrs = {}
        for attr in dir(func):
            if attr.startswith("__") and attr.endswith("__"):
                continue
            if attr in self._FUNCTION_ATTRS:
                continue
            try:
                attrs[attr] = getattr(func, attr)
            except (AttributeError, TypeError):
                pass
        # Keyword-only defaults cannot be passed to types.FunctionType; carry
        # them in ``attrs`` so the reader restores them with setattr.
        if func.__kwdefaults__ is not None:
            attrs["__kwdefaults__"] = func.__kwdefaults__
        fory.write_ref(buffer, attrs)

    def _deserialize_function(self, buffer):
        """Deserialize a function from its components."""
        fory = self.fory
        func_type_id = buffer.read_int8()
        if func_type_id == 0:
            # Bound method
            self_obj = fory.read_ref(buffer)
            method_name = buffer.read_string()
            func = getattr(self_obj, method_name)
            result = fory.policy.validate_function(func, is_local=False)
            return func if result is None else result

        if func_type_id == 1:
            # Importable function: import the module and walk the qualname.
            module = buffer.read_string()
            qualname = buffer.read_string()
            obj = importlib.import_module(module)
            for name in qualname.split("."):
                obj = getattr(obj, name)
            result = fory.policy.validate_function(obj, is_local=False)
            return obj if result is None else result

        # Local or __main__ function
        module = buffer.read_string()
        qualname = buffer.read_string()
        name = qualname.rsplit(".", 1)[-1]

        code = marshal.loads(buffer.read_bytes_and_size())

        defaults = None
        if buffer.read_bool():
            num_defaults = buffer.read_varuint32()
            defaults = tuple(fory.read_ref(buffer) for _ in range(num_defaults))

        has_closure = buffer.read_bool()
        num_freevars = buffer.read_varuint32()
        closure_values = []
        if has_closure:
            closure_values = [fory.read_ref(buffer) for _ in range(num_freevars)]
        closure = tuple(types.CellType(value) for value in closure_values)

        # Free-variable names are consumed to keep the stream aligned; the
        # code object already carries them in ``co_freevars``.
        for _ in range(buffer.read_varuint32()):
            buffer.read_string()

        globals_dict = fory.read_ref(buffer)

        # Best effort: start from the defining module's globals when it can be
        # imported in this process, then overlay the serialized globals.
        func_globals = {}
        try:
            mod = importlib.import_module(module)
            if mod:
                func_globals.update(mod.__dict__)
        except (ImportError, KeyError, AttributeError):
            # The module may be unavailable here; the serialized globals are
            # then the only globals the function gets.
            pass
        func_globals.update(globals_dict)

        # Ensure __builtins__ is available
        if "__builtins__" not in func_globals:
            func_globals["__builtins__"] = builtins

        func = types.FunctionType(code, func_globals, name, defaults, closure)
        func.__module__ = module
        func.__qualname__ = qualname

        # Additional attributes (including __kwdefaults__ when present)
        attrs = fory.read_ref(buffer)
        for attr_name, attr_value in attrs.items():
            setattr(func, attr_name, attr_value)

        result = fory.policy.validate_function(func, is_local=True)
        return func if result is None else result

    def xwrite(self, buffer, value):
        raise NotImplementedError()

    def xread(self, buffer):
        raise NotImplementedError()

    def write(self, buffer, value):
        """Serialize a function for Python-only mode."""
        self._serialize_function(buffer, value)

    def read(self, buffer):
        """Deserialize a function for Python-only mode."""
        return self._deserialize_function(buffer)
| |
| |
class NativeFuncMethodSerializer(Serializer):
    """Serializes a callable by reference rather than by code.

    Writes ``__name__`` and then a flag. ``True`` is followed by the defining
    module's name, used when there is no ``__self__`` or it is a module.
    ``False`` is followed by the bound ``__self__`` object (``write_ref``).
    """

    def write(self, buffer, func):
        buffer.write_string(func.__name__)
        owner = getattr(func, "__self__", None)
        module_level = owner is None or inspect.ismodule(owner)
        buffer.write_bool(module_level)
        if module_level:
            buffer.write_string(func.__module__)
        else:
            self.fory.write_ref(buffer, owner)

    def read(self, buffer):
        name = buffer.read_string()
        if buffer.read_bool():
            owner = importlib.import_module(buffer.read_string())
        else:
            owner = self.fory.read_ref(buffer)
        func = getattr(owner, name)
        validated = self.fory.policy.validate_function(func, is_local=False)
        return func if validated is None else validated
| |
| |
class MethodSerializer(Serializer):
    """Serializer for bound method objects.

    A bound method is written as its ``__self__`` (``write_ref``) plus the
    name of its underlying function, and rebuilt with ``getattr``. The same
    encoding is used in both native and cross-language mode.
    """

    def __init__(self, fory, cls):
        super().__init__(fory, cls)
        self.cls = cls

    def write(self, buffer, value):
        owner, name = value.__self__, value.__func__.__name__
        self.fory.write_ref(buffer, owner)
        buffer.write_string(name)

    def read(self, buffer):
        owner = self.fory.read_ref(buffer)
        method = getattr(owner, buffer.read_string())
        owner_cls = method.__self__.__class__
        # Classes from __main__ or defined inside a function count as local.
        local = owner_cls.__module__ == "__main__" or "<locals>" in owner_cls.__qualname__
        checked = self.fory.policy.validate_method(method, is_local=local)
        return method if checked is None else checked

    def xwrite(self, buffer, value):
        return self.write(buffer, value)

    def xread(self, buffer):
        return self.read(buffer)
| |
| |
class ObjectSerializer(Serializer):
    """Serializer for regular Python objects.
    It serializes objects based on `__dict__` or `__slots__`.

    Fields are written as a varuint32 count followed by ``(name, value)``
    pairs in sorted name order; reading restores each pair with ``setattr``
    on an instance created via ``__new__``.
    """

    def __init__(self, fory, clz: type):
        super().__init__(fory, clz)
        # Sorted slot attribute names from every class in the MRO, or None
        # when no class in the MRO declares __slots__.
        self._slot_field_names = self._collect_slot_names(clz)

    @staticmethod
    def _collect_slot_names(clz):
        """Return the sorted data-slot attribute names of ``clz``'s MRO, or None."""
        names = set()
        has_slots = False
        for klass in clz.__mro__:
            slots = klass.__dict__.get("__slots__")
            if slots is None:
                continue
            has_slots = True
            # __slots__ can be a string or iterable of strings
            if isinstance(slots, str):
                slots = (slots,)
            for slot in slots:
                # These entries enable instance __dict__ / weakrefs; they are
                # not data fields.
                if slot in ("__dict__", "__weakref__"):
                    continue
                # Private slot names are stored under their mangled name.
                if slot.startswith("__") and not slot.endswith("__"):
                    owner = klass.__name__.lstrip("_")
                    if owner:
                        slot = f"_{owner}{slot}"
                names.add(slot)
        return sorted(names) if has_slots else None

    def write(self, buffer, value):
        slot_names = self._slot_field_names
        fields = {}
        if slot_names is None:
            # No __slots__ anywhere: all fields live in __dict__ (raises
            # AttributeError if the instance has none).
            instance_dict = value.__dict__
        else:
            # A slotted class may still have an instance __dict__ (e.g. a
            # subclass without __slots__), so collect both.
            instance_dict = getattr(value, "__dict__", None)
            for name in slot_names:
                try:
                    fields[name] = getattr(value, name)
                except AttributeError:
                    # Unassigned slot: leave it unassigned on the reader too.
                    continue
        if instance_dict:
            for name in instance_dict:
                fields[name] = getattr(value, name)

        buffer.write_varuint32(len(fields))
        fory = self.fory
        for field_name in sorted(fields):
            buffer.write_string(field_name)
            fory.write_ref(buffer, fields[field_name])

    def read(self, buffer):
        fory = self.fory
        fory.policy.authorize_instantiation(self.type_)
        obj = self.type_.__new__(self.type_)
        # Register before reading fields so cyclic references resolve to obj.
        fory.ref_resolver.reference(obj)
        num_fields = buffer.read_varuint32()
        for _ in range(num_fields):
            field_name = buffer.read_string()
            field_value = fory.read_ref(buffer)
            setattr(obj, field_name, field_value)
        return obj

    def xwrite(self, buffer, value):
        # for cross-language or minimal framing, reuse the same logic
        return self.write(buffer, value)

    def xread(self, buffer):
        # symmetric to xwrite
        return self.read(buffer)
| |
| |
@dataclasses.dataclass
class NonExistEnum:
    """Stand-in value produced by ``NonExistEnumSerializer``.

    Presumably used for enum members whose enum type cannot be resolved
    when reading. Only one field is populated per mode: ``name`` in
    Python-native mode, ``value`` in cross-language mode.
    """

    value: int = dataclasses.field(default=-1)
    name: str = dataclasses.field(default="")
| |
| |
class NonExistEnumSerializer(Serializer):
    """Serializer for ``NonExistEnum`` placeholders.

    Python-native mode round-trips only the member name; cross-language mode
    round-trips only the numeric value (as a varuint32).
    """

    def __init__(self, fory):
        super().__init__(fory, NonExistEnum)
        # Presumably disables reference tracking for these values.
        self.need_to_write_ref = False

    @classmethod
    def support_subclass(cls) -> bool:
        return True

    def write(self, buffer, value):
        buffer.write_string(value.name)

    def read(self, buffer):
        return NonExistEnum(name=buffer.read_string())

    def xwrite(self, buffer, value):
        # NOTE(review): placeholders created by ``read`` keep value == -1,
        # which an unsigned varint cannot represent; confirm they never
        # reach xwrite.
        buffer.write_varuint32(value.value)

    def xread(self, buffer):
        return NonExistEnum(value=buffer.read_varuint32())
| |
| |
class UnsupportedSerializer(Serializer):
    """Serializer that defers values to Fory's unsupported-type hooks.

    Python-native mode delegates to ``fory.handle_unsupported_write`` and
    ``fory.handle_unsupported_read``; cross-language mode always raises
    ``NotImplementedError``.
    """

    def xwrite(self, buffer, value):
        raise NotImplementedError(f"{self.type_} is not supported for xwrite")

    def xread(self, buffer):
        raise NotImplementedError(f"{self.type_} is not supported for xread")

    def write(self, buffer, value):
        fory = self.fory
        fory.handle_unsupported_write(value)

    def read(self, buffer):
        fory = self.fory
        return fory.handle_unsupported_read(buffer)