blob: f0785b0541a745c72fc2ca31aaf7c27584f02e70 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Python code generator."""
import keyword
from typing import List, Optional, Set
from fory_compiler.generators.base import BaseGenerator, GeneratedFile
from fory_compiler.ir.ast import (
Message,
Enum,
Union,
Field,
FieldType,
PrimitiveType,
NamedType,
ListType,
MapType,
)
from fory_compiler.ir.types import PrimitiveKind
class PythonGenerator(BaseGenerator):
"""Generates Python dataclasses with pyfory type hints.

The generator emits one Python module per schema. The lookup tables
below hold the source text that the generator methods paste into that
module; their values are strings of generated code, not Python objects.
"""
# Target-language identifier and the suffix of the emitted file.
language_name = "python"
file_extension = ".py"
# Mapping from FDL primitive types to Python types
# (used for scalar field annotations and for union case checks).
# Fixed-width kinds map to the pyfory.fixed_* hints, while the VAR* kinds
# map to the plain pyfory.int32/int64/uint32/uint64 hints.
# NOTE(review): FLOAT16 is mapped to the float32 hint here and in the two
# array tables below - presumably a deliberate widening; confirm.
PRIMITIVE_MAP = {
PrimitiveKind.BOOL: "bool",
PrimitiveKind.INT8: "pyfory.int8",
PrimitiveKind.INT16: "pyfory.int16",
PrimitiveKind.INT32: "pyfory.fixed_int32",
PrimitiveKind.VARINT32: "pyfory.int32",
PrimitiveKind.INT64: "pyfory.fixed_int64",
PrimitiveKind.VARINT64: "pyfory.int64",
PrimitiveKind.TAGGED_INT64: "pyfory.tagged_int64",
PrimitiveKind.UINT8: "pyfory.uint8",
PrimitiveKind.UINT16: "pyfory.uint16",
PrimitiveKind.UINT32: "pyfory.fixed_uint32",
PrimitiveKind.VAR_UINT32: "pyfory.uint32",
PrimitiveKind.UINT64: "pyfory.fixed_uint64",
PrimitiveKind.VAR_UINT64: "pyfory.uint64",
PrimitiveKind.TAGGED_UINT64: "pyfory.tagged_uint64",
PrimitiveKind.FLOAT16: "pyfory.float32",
PrimitiveKind.FLOAT32: "pyfory.float32",
PrimitiveKind.FLOAT64: "pyfory.float64",
PrimitiveKind.STRING: "str",
PrimitiveKind.BYTES: "bytes",
PrimitiveKind.DATE: "datetime.date",
PrimitiveKind.TIMESTAMP: "datetime.datetime",
PrimitiveKind.ANY: "Any",
}
# Numpy dtype strings for primitive arrays
# (only used to build the "Use np.array([], dtype=...)" hint comment that
# get_default_value attaches to the default of such list fields).
NUMPY_DTYPE_MAP = {
PrimitiveKind.BOOL: "np.bool_",
PrimitiveKind.INT8: "np.int8",
PrimitiveKind.INT16: "np.int16",
PrimitiveKind.INT32: "np.int32",
PrimitiveKind.VARINT32: "np.int32",
PrimitiveKind.INT64: "np.int64",
PrimitiveKind.VARINT64: "np.int64",
PrimitiveKind.TAGGED_INT64: "np.int64",
PrimitiveKind.UINT8: "np.uint8",
PrimitiveKind.UINT16: "np.uint16",
PrimitiveKind.UINT32: "np.uint32",
PrimitiveKind.VAR_UINT32: "np.uint32",
PrimitiveKind.UINT64: "np.uint64",
PrimitiveKind.VAR_UINT64: "np.uint64",
PrimitiveKind.TAGGED_UINT64: "np.uint64",
PrimitiveKind.FLOAT16: "np.float32",
PrimitiveKind.FLOAT32: "np.float32",
PrimitiveKind.FLOAT64: "np.float64",
}
# Type hints for lists of non-optional primitives. Membership in this table
# is what makes a list field an ndarray hint instead of List[...]; kinds
# that are absent (string, bytes, date, timestamp, any) stay plain lists.
ARRAY_TYPE_HINTS = {
PrimitiveKind.BOOL: "pyfory.bool_ndarray",
PrimitiveKind.INT8: "pyfory.int8_ndarray",
PrimitiveKind.INT16: "pyfory.int16_ndarray",
PrimitiveKind.INT32: "pyfory.int32_ndarray",
PrimitiveKind.VARINT32: "pyfory.int32_ndarray",
PrimitiveKind.INT64: "pyfory.int64_ndarray",
PrimitiveKind.VARINT64: "pyfory.int64_ndarray",
PrimitiveKind.TAGGED_INT64: "pyfory.int64_ndarray",
PrimitiveKind.UINT8: "pyfory.uint8_ndarray",
PrimitiveKind.UINT16: "pyfory.uint16_ndarray",
PrimitiveKind.UINT32: "pyfory.uint32_ndarray",
PrimitiveKind.VAR_UINT32: "pyfory.uint32_ndarray",
PrimitiveKind.UINT64: "pyfory.uint64_ndarray",
PrimitiveKind.VAR_UINT64: "pyfory.uint64_ndarray",
PrimitiveKind.TAGGED_UINT64: "pyfory.uint64_ndarray",
PrimitiveKind.FLOAT16: "pyfory.float32_ndarray",
PrimitiveKind.FLOAT32: "pyfory.float32_ndarray",
PrimitiveKind.FLOAT64: "pyfory.float64_ndarray",
}
# Default values for primitive types
# (source text of the default expression for a non-optional scalar field).
DEFAULT_VALUES = {
PrimitiveKind.BOOL: "False",
PrimitiveKind.INT8: "0",
PrimitiveKind.INT16: "0",
PrimitiveKind.INT32: "0",
PrimitiveKind.VARINT32: "0",
PrimitiveKind.INT64: "0",
PrimitiveKind.VARINT64: "0",
PrimitiveKind.TAGGED_INT64: "0",
PrimitiveKind.UINT8: "0",
PrimitiveKind.UINT16: "0",
PrimitiveKind.UINT32: "0",
PrimitiveKind.VAR_UINT32: "0",
PrimitiveKind.UINT64: "0",
PrimitiveKind.VAR_UINT64: "0",
PrimitiveKind.TAGGED_UINT64: "0",
PrimitiveKind.FLOAT16: "0.0",
PrimitiveKind.FLOAT32: "0.0",
PrimitiveKind.FLOAT64: "0.0",
PrimitiveKind.STRING: '""',
PrimitiveKind.BYTES: 'b""',
PrimitiveKind.DATE: "None",
PrimitiveKind.TIMESTAMP: "None",
PrimitiveKind.ANY: "None",
}
def safe_name(self, name: str) -> str:
    """Turn ``name`` into an identifier that Python will accept.

    Reserved words, as reported by ``keyword.iskeyword``, get a trailing
    underscore; any other name is handed back untouched.
    """
    return f"{name}_" if keyword.iskeyword(name) else name
def generate(self) -> List[GeneratedFile]:
    """Generate the Python output for the schema.

    Returns:
        A one-element list: all types of the schema are written into the
        single module produced by ``generate_module``.
    """
    return [self.generate_module()]
def get_module_name(self) -> str:
    """Return the name of the generated Python module.

    The schema package with every dot replaced by an underscore, or
    ``"generated"`` when no package is set.
    """
    package = self.package
    return package.replace(".", "_") if package else "generated"
def generate_module(self) -> GeneratedFile:
    """Build the single generated module that holds every schema type.

    The module text is laid out in this order: license header, the
    ``from __future__ import annotations`` line, the sorted import lines,
    top-level enums, top-level unions, messages (each with its nested
    types) and finally the registration function. Every definition is
    followed by two blank lines.

    Returns:
        A ``GeneratedFile`` whose path is ``<module name>.py`` and whose
        content is the newline-joined module text.
    """
    schema = self.schema

    # Imports that every generated module needs, whatever the schema holds.
    imports: Set[str] = {
        "from dataclasses import dataclass, field",
        "from enum import Enum, IntEnum",
        "from typing import Dict, List, Optional, cast",
        "import pyfory",
    }
    if self.schema_has_ref_elements():
        imports.add("from pyfory import Ref")
    # Schema-dependent imports (datetime, numpy, Any, union support, ...).
    for message in schema.messages:
        self.collect_message_imports(message, imports)
    for union in schema.unions:
        self.collect_union_imports(union, imports)

    separator = ["", ""]

    # Header section: license, future import, then the import block.
    lines = [
        self.get_license_header("#"),
        "",
        "from __future__ import annotations",
        "",
    ]
    lines.extend(sorted(imports))
    lines.extend(separator)

    # Top-level enums and unions come before the messages.
    for enum in schema.enums:
        lines.extend(self.generate_enum(enum))
        lines.extend(separator)
    for union in schema.unions:
        lines.extend(self.generate_union(union))
        lines.extend(separator)

    # Messages carry their nested enums, unions and messages with them.
    for message in schema.messages:
        lines.extend(self.generate_message(message, indent=0))
        lines.extend(separator)

    lines.extend(self.generate_registration())
    lines.append("")

    return GeneratedFile(
        path=f"{self.get_module_name()}.py",
        content="\n".join(lines),
    )
def collect_message_imports(self, message: Message, imports: Set[str]):
    """Add the imports needed by ``message`` and everything nested in it.

    Fields are visited first, then nested messages (recursively), then
    nested unions. ``imports`` is updated in place.
    """
    visitors = (
        (message.fields, self.collect_field_imports),
        (message.nested_messages, self.collect_message_imports),
        (message.nested_unions, self.collect_union_imports),
    )
    for members, collect in visitors:
        for member in members:
            collect(member, imports)
def collect_union_imports(self, union: Union, imports: Set[str]):
    """Add the imports needed by ``union`` to ``imports`` in place.

    Any union needs the pyfory union base class and serializer; on top of
    that each case contributes the imports of its own type.
    """
    union_support = "from pyfory.union import Union, UnionSerializer"
    imports.add(union_support)
    for case in union.fields:
        self.collect_field_imports(case, imports)
def generate_enum(self, enum: Enum, indent: int = 0) -> List[str]:
    """Generate the source lines of a Python ``IntEnum`` class.

    Args:
        enum: Enum definition; its ``name`` and ``values`` are read.
        indent: Nesting level of the class statement.

    Returns:
        The class header followed by one ``NAME = value`` line per enum
        value. Member names have the enum prefix stripped via
        ``strip_enum_prefix``. An enum without values gets a ``pass``
        body so that the emitted class is still valid Python.
    """
    ind = " " * indent
    lines = [f"{ind}class {enum.name}(IntEnum):"]
    # Enum values (strip prefix for scoped enums)
    lines.extend(
        f"{ind} {self.strip_enum_prefix(enum.name, value.name)} = {value.value}"
        for value in enum.values
    )
    if len(lines) == 1:
        # A class statement with an empty suite is a syntax error.
        lines.append(f"{ind} pass")
    return lines
def generate_message(
    self,
    message: Message,
    indent: int = 0,
    parent_stack: Optional[List[Message]] = None,
) -> List[str]:
    """Generate the source lines of a dataclass, nested types included.

    Args:
        message: Message definition to render.
        indent: Nesting level of the class statement.
        parent_stack: Enclosing messages, outermost first; ``None`` for a
            top-level message.

    Returns:
        The ``@dataclass`` decorator and class header, then nested enums,
        nested unions and nested messages (each group member followed by
        a blank line), then the field lines. A message without fields
        ends with ``pass``.
    """
    ind = " " * indent
    member_indent = indent + 1
    lineage = (parent_stack or []) + [message]

    lines = [f"{ind}@dataclass", f"{ind}class {message.name}:"]

    # Nested types are emitted ahead of the fields that may refer to them.
    for nested_enum in message.nested_enums:
        lines.extend(self.generate_enum(nested_enum, indent=member_indent))
        lines.append("")
    for nested_union in message.nested_unions:
        lines.extend(
            self.generate_union(
                nested_union, indent=member_indent, parent_stack=lineage
            )
        )
        lines.append("")
    for nested_msg in message.nested_messages:
        lines.extend(
            self.generate_message(
                nested_msg, indent=member_indent, parent_stack=lineage
            )
        )
        lines.append("")

    if not message.fields:
        # No fields: close the class body with ``pass``, whether or not
        # nested types were emitted above.
        lines.append(f"{ind} pass")
        return lines

    for field in message.fields:
        lines.extend(
            f"{ind} {line}" for line in self.generate_field(field, lineage)
        )
    return lines
def generate_union(
    self,
    union: Union,
    indent: int = 0,
    parent_stack: Optional[List[Message]] = None,
) -> List[str]:
    """Generate the source lines of a Python tagged union.

    Three classes are emitted at the same nesting level: a ``<Name>Case``
    enum with one member per case, the ``<Name>`` union class itself, and
    the serializer class produced by ``generate_union_serializer``.

    The union class gets, per case, a factory classmethod, an
    ``is_<case>`` test, a ``<case>_value`` getter and a ``set_<case>``
    setter, plus ``_from_case_id``, ``_validate``, ``case``, ``case_id``
    and ``__eq__``.

    Args:
        union: Union definition; ``name`` and ``fields`` are read.
        indent: Nesting level of the emitted class statements.
        parent_stack: Enclosing messages, outermost first. When given,
            references to the union and its case enum inside the
            generated code are qualified with the dotted parent path.

    Returns:
        The generated source lines.
    """
    lines: List[str] = []
    ind = " " * indent
    # One nesting level of generated code - the same unit ``ind`` is made of.
    step = " "

    def emit(depth: int, text: str) -> None:
        """Append ``text`` nested ``depth`` levels below the class keyword."""
        lines.append(f"{ind}{step * depth}{text}")

    parent_path = ".".join(msg.name for msg in parent_stack) if parent_stack else ""
    case_enum = f"{union.name}Case"
    case_enum_ref = f"{parent_path}.{case_enum}" if parent_path else case_enum
    union_ref = f"{parent_path}.{union.name}" if parent_path else union.name

    # Per-case names are needed by several sections; resolve them once.
    cases = [
        (
            self.to_upper_snake_case(field.name),
            self.safe_name(self.to_snake_case(field.name)),
            self.get_union_case_type(field, parent_stack),
            field,
        )
        for field in union.fields
    ]

    # --- case enum -------------------------------------------------------
    emit(0, f"class {case_enum}(Enum):")
    for case_name, _, _, field in cases:
        emit(1, f"{case_name} = {field.number}")
    if not cases:
        # Keep the class body non-empty so the generated module compiles.
        emit(1, "pass")
    lines.append("")

    # --- union class: construction ---------------------------------------
    emit(0, f"class {union.name}(Union):")
    emit(1, '__slots__ = ("_case",)')
    lines.append("")
    emit(1, f"def __init__(self, case: {case_enum_ref}, value: object) -> None:")
    emit(2, "super().__init__(case.value, value)")
    emit(2, "self._case = case")
    emit(2, "self._validate()")
    lines.append("")

    # One factory classmethod per case.
    for case_name, method_name, case_type, _ in cases:
        emit(1, "@classmethod")
        emit(1, f'def {method_name}(cls, v: {case_type}) -> "{union_ref}":')
        emit(2, f"return cls({case_enum_ref}.{case_name}, v)")
        lines.append("")

    # Construction from a numeric case id.
    emit(1, "@classmethod")
    emit(1, f'def _from_case_id(cls, case_id: int, value: object) -> "{union_ref}":')
    for case_name, _, _, _ in cases:
        emit(2, f"if case_id == {case_enum_ref}.{case_name}.value:")
        emit(3, f"return cls({case_enum_ref}.{case_name}, value)")
    emit(2, f'raise ValueError("unknown {union.name} case id: {{}}".format(case_id))')
    lines.append("")

    # --- union class: validation -----------------------------------------
    emit(1, "def _validate(self) -> None:")
    has_checks = False
    for case_name, method_name, case_type, field in cases:
        check_expr = self.get_union_case_check(field, parent_stack)
        if not check_expr:
            # No runtime check is available for this case type.
            continue
        has_checks = True
        emit(2, f"if self._case == {case_enum_ref}.{case_name} and not {check_expr}:")
        emit(3, f'raise TypeError("{union.name}.{method_name}(...) requires {case_type}")')
    if not has_checks:
        emit(2, "pass")
    lines.append("")

    # --- union class: accessors and equality -----------------------------
    emit(1, f"def case(self) -> {case_enum_ref}:")
    emit(2, "return self._case")
    lines.append("")
    emit(1, "def case_id(self) -> int:")
    emit(2, "return self._case_id")
    lines.append("")
    emit(1, "def __eq__(self, other: object) -> bool:")
    emit(2, f"if not isinstance(other, {union_ref}):")
    emit(3, "return NotImplemented")
    emit(2, "return self._case == other._case and self._value == other._value")
    lines.append("")

    # Per-case test, getter and setter.
    for case_name, method_name, case_type, _ in cases:
        emit(1, f"def is_{method_name}(self) -> bool:")
        emit(2, f"return self._case == {case_enum_ref}.{case_name}")
        lines.append("")
        emit(1, f"def {method_name}_value(self) -> {case_type}:")
        emit(2, f"if self._case != {case_enum_ref}.{case_name}:")
        emit(3, f'raise ValueError("{union.name} is not {case_name.lower()}")')
        emit(2, f"return cast({case_type}, self._value)")
        lines.append("")
        emit(1, f"def set_{method_name}(self, v: {case_type}) -> None:")
        emit(2, f"self._case = {case_enum_ref}.{case_name}")
        emit(2, f"self._case_id = {case_enum_ref}.{case_name}.value")
        emit(2, "self._value = v")
        emit(2, "self._validate()")
        lines.append("")

    lines.extend(self.generate_union_serializer(union, indent, parent_stack))
    return lines
def generate_union_serializer(
    self,
    union: Union,
    indent: int = 0,
    parent_stack: Optional[List[Message]] = None,
) -> List[str]:
    """Generate the source lines of the serializer class for a union.

    The emitted ``<Name>Serializer`` class derives from ``UnionSerializer``
    and only defines ``__init__``, which passes the union class and a
    ``{case number: case type}`` mapping to the base class.

    Args:
        union: Union definition; ``name`` and ``fields`` are read.
        indent: Nesting level of the emitted class statement.
        parent_stack: Enclosing messages, outermost first; used to qualify
            the reference to the union class.

    Returns:
        The generated source lines, ending with a blank line.
    """
    lines: List[str] = []
    ind = " " * indent
    # One nesting level of generated code - the same unit ``ind`` is made of.
    step = " "

    def emit(depth: int, text: str) -> None:
        """Append ``text`` nested ``depth`` levels below the class keyword."""
        lines.append(f"{ind}{step * depth}{text}")

    parent_path = ".".join(msg.name for msg in parent_stack) if parent_stack else ""
    union_ref = f"{parent_path}.{union.name}" if parent_path else union.name

    emit(0, f"class {union.name}Serializer(UnionSerializer):")
    lines.append("")
    emit(1, "def __init__(self, fory: pyfory.Fory):")
    # The call must sit one level below the ``def`` line to form its body.
    emit(2, f"super().__init__(fory, {union_ref}, {{")
    for field in union.fields:
        case_type = self.get_union_case_type(field, parent_stack)
        emit(3, f"{field.number}: {case_type},")
    emit(2, "})")
    lines.append("")
    return lines
def generate_field(
    self,
    field: Field,
    parent_stack: Optional[List[Message]] = None,
) -> List[str]:
    """Generate the dataclass attribute line for one message field.

    Args:
        field: Field definition to render.
        parent_stack: Enclosing messages, outermost first; used to resolve
            references to nested types.

    Returns:
        A one-element list holding ``name: type = default``. The default
        is a ``pyfory.field(...)`` call when the field has a tag id, is a
        reference field or is nullable; otherwise it is a plain value or a
        ``field(default_factory=...)`` call.
    """
    field_type = field.field_type
    # ``any`` fields are treated as nullable even when not declared optional.
    is_any = (
        isinstance(field_type, PrimitiveType)
        and field_type.kind == PrimitiveKind.ANY
    )
    nullable = field.optional or is_any

    python_type = self.generate_type(
        field_type,
        nullable,
        field.element_optional,
        field.element_ref,
        parent_stack,
    )
    field_name = self.safe_name(self.to_snake_case(field.name))
    default_factory = self.get_default_factory(field)
    default = self.get_default_value(field_type, field.optional)

    # A default may carry a trailing hint comment. Split it off so it can be
    # re-attached after the closing parenthesis of a field(...) call.
    default_expr, marker, comment = default.partition(" # ")
    trailing_comment = f" # {comment}" if marker else ""

    tag_id = field.tag_id
    if tag_id is None and not field.ref and not nullable:
        # Nothing pyfory-specific to declare: plain dataclass default.
        if default_factory is not None:
            field_default = f"field(default_factory={default_factory})"
        else:
            field_default = f"{default_expr}{trailing_comment}"
    else:
        field_args = []
        if tag_id is not None:
            field_args.append(f"id={tag_id}")
        if nullable:
            field_args.append("nullable=True")
        if field.ref:
            field_args.append("ref=True")
        if default_factory is not None:
            field_args.append(f"default_factory={default_factory}")
        else:
            field_args.append(f"default={default_expr}")
        field_default = f"pyfory.field({', '.join(field_args)}){trailing_comment}"

    return [f"{field_name}: {python_type} = {field_default}"]
def uses_numpy_array(self, field_type: ListType, element_optional: bool) -> bool:
    """Tell whether a list is rendered with an ndarray type hint.

    That is the case when the elements are primitives whose kind appears
    in ``ARRAY_TYPE_HINTS`` and the elements are not optional.
    """
    element = field_type.element_type
    return (
        isinstance(element, PrimitiveType)
        and element.kind in self.ARRAY_TYPE_HINTS
        and not element_optional
    )
def get_default_factory(self, field: Field) -> Optional[str]:
    """Return the ``default_factory`` name for a collection field.

    Returns:
        ``"list"`` for a list field that is not rendered as an ndarray,
        ``"dict"`` for a map field, and ``None`` for optional fields,
        ndarray lists and every other field type.
    """
    if field.optional:
        return None
    field_type = field.field_type
    if isinstance(field_type, ListType):
        as_array = self.uses_numpy_array(field_type, field.element_optional)
        return None if as_array else "list"
    return "dict" if isinstance(field_type, MapType) else None
def generate_type(
    self,
    field_type: FieldType,
    nullable: bool = False,
    element_optional: bool = False,
    element_ref: bool = False,
    parent_stack: Optional[List[Message]] = None,
) -> str:
    """Generate the Python type hint for a schema type.

    Args:
        field_type: Schema type to render.
        nullable: Wrap the result in ``Optional[...]``. Has no effect on
            the ``any`` primitive, which is always rendered as ``Any``.
        element_optional: For lists, wrap the element hint in
            ``Optional[...]``; this also rules out the ndarray hint.
        element_ref: For lists, passed to ``wrap_ref_type`` for the
            element hint.
        parent_stack: Enclosing messages, outermost first; used to qualify
            nested type names.

    Returns:
        The hint as source text; ``"object"`` for an unrecognised type.
    """
    if isinstance(field_type, PrimitiveType):
        if field_type.kind == PrimitiveKind.ANY:
            return "Any"
        hint = self.PRIMITIVE_MAP[field_type.kind]
    elif isinstance(field_type, NamedType):
        hint = self.resolve_nested_type_name(field_type.name, parent_stack)
    elif isinstance(field_type, ListType):
        element = field_type.element_type
        # Use numpy array for numeric primitive types
        if (
            isinstance(element, PrimitiveType)
            and not element_optional
            and element.kind in self.ARRAY_TYPE_HINTS
        ):
            hint = self.ARRAY_TYPE_HINTS[element.kind]
        else:
            # Every other list is a List[...] of the (possibly Ref-wrapped,
            # possibly Optional) element hint.
            element_hint = self.generate_type(
                element, False, False, False, parent_stack
            )
            element_hint = self.wrap_ref_type(
                element, element_hint, element_ref=element_ref
            )
            if element_optional:
                element_hint = f"Optional[{element_hint}]"
            hint = f"List[{element_hint}]"
    elif isinstance(field_type, MapType):
        key_hint = self.generate_type(
            field_type.key_type, False, False, False, parent_stack
        )
        value_hint = self.generate_type(
            field_type.value_type, False, False, False, parent_stack
        )
        value_hint = self.wrap_ref_type(
            field_type.value_type,
            value_hint,
            element_ref=field_type.value_ref,
        )
        hint = f"Dict[{key_hint}, {value_hint}]"
    else:
        return "object"

    return f"Optional[{hint}]" if nullable else hint
def schema_has_ref_elements(self) -> bool:
    """Tell whether any field in the schema needs the ``Ref`` type hint.

    Checks every top-level message (via ``message_has_ref_elements``) and
    the cases of every top-level union.
    """
    schema = self.schema
    if any(self.message_has_ref_elements(message) for message in schema.messages):
        return True
    return any(
        self.field_uses_ref_type(case)
        for union in schema.unions
        for case in union.fields
    )
def message_has_ref_elements(self, message: Message) -> bool:
    """Tell whether ``message`` contains a field needing the ``Ref`` hint.

    Looks at the message's own fields, then recursively at its nested
    messages, then at the cases of its nested unions.
    """
    if any(self.field_uses_ref_type(field) for field in message.fields):
        return True
    if any(
        self.message_has_ref_elements(child) for child in message.nested_messages
    ):
        return True
    return any(
        self.field_uses_ref_type(case)
        for nested_union in message.nested_unions
        for case in nested_union.fields
    )
def field_uses_ref_type(self, field: Field) -> bool:
    """Tell whether the type hint of ``field`` involves ``Ref``.

    Only collections qualify: a list whose element type, or a map whose
    value type, is a ref target according to ``is_ref_target_type``.
    """
    field_type = field.field_type
    if isinstance(field_type, ListType):
        contained = field_type.element_type
    elif isinstance(field_type, MapType):
        contained = field_type.value_type
    else:
        return False
    return self.is_ref_target_type(contained)
def is_ref_target_type(self, field_type: FieldType) -> bool:
    """Tell whether ``field_type`` names a message or union of the schema.

    Only named types are looked up; any other type yields ``False``.
    """
    return isinstance(field_type, NamedType) and isinstance(
        self.schema.get_type(field_type.name), (Message, Union)
    )
def wrap_ref_type(
    self,
    field_type: FieldType,
    element_type: str,
    element_ref: bool,
) -> str:
    """Wrap a collection element hint in ``Ref[...]`` where applicable.

    Args:
        field_type: Schema type of the element (or map value).
        element_type: Already rendered hint for that type.
        element_ref: Whether the element is declared as a reference.

    Returns:
        ``element_type`` unchanged when the type is not a ref target;
        otherwise ``Ref[T]`` for reference elements and ``Ref[T, False]``
        for the rest.
    """
    if self.is_ref_target_type(field_type):
        if element_ref:
            return f"Ref[{element_type}]"
        return f"Ref[{element_type}, False]"
    return element_type
def get_union_case_type(
    self, field: Field, parent_stack: Optional[List[Message]] = None
) -> str:
    """Return the Python type hint of a union case.

    The hint is produced by ``generate_type`` and is never nullable; the
    case's element modifiers and ``parent_stack`` are passed through.
    """
    not_nullable = False
    return self.generate_type(
        field.field_type,
        not_nullable,
        field.element_optional,
        field.element_ref,
        parent_stack,
    )
def get_union_case_check(
    self, field: Field, parent_stack: Optional[List[Message]] = None
) -> Optional[str]:
    """Return the runtime type check for a union case, if there is one.

    The check is the expression built by ``get_union_case_runtime_check``
    for the stored value ``self._value`` of the generated union.
    """
    stored_value = "self._value"
    return self.get_union_case_runtime_check(field, parent_stack, stored_value)
def get_union_case_runtime_check(
    self,
    field: Field,
    parent_stack: Optional[List[Message]],
    value_expr: str,
) -> Optional[str]:
    """Build an ``isinstance`` expression that validates a union case value.

    Args:
        field: Union case definition.
        parent_stack: Enclosing messages, outermost first; used to qualify
            nested type names.
        value_expr: Source text of the value to test.

    Returns:
        The expression as source text, or ``None`` when no check exists
        for the case type (lists, maps and the ``any`` primitive).
    """
    field_type = field.field_type

    if isinstance(field_type, PrimitiveType):
        base = self.PRIMITIVE_MAP[field_type.kind]
        if base.startswith("pyfory."):
            # pyfory numeric hints are checked against the builtin number
            # type: float for the float hints, int for all the others.
            runtime_type = "float" if "float" in base else "int"
            return f"isinstance({value_expr}, {runtime_type})"
        runtime_types = {
            "bool": "bool",
            "str": "str",
            "bytes": "(bytes, bytearray)",
            "datetime.date": "datetime.date",
            "datetime.datetime": "datetime.datetime",
        }
        runtime_type = runtime_types.get(base)
        if runtime_type is not None:
            return f"isinstance({value_expr}, {runtime_type})"

    if isinstance(field_type, NamedType):
        type_name = self.resolve_nested_type_name(field_type.name, parent_stack)
        return f"isinstance({value_expr}, {type_name})"

    return None
def resolve_nested_type_name(
    self,
    type_name: str,
    parent_stack: Optional[List[Message]] = None,
) -> str:
    """Qualify a nested type name with the messages that enclose it.

    Args:
        type_name: Name as written in the schema.
        parent_stack: Enclosing messages, outermost first.

    Returns:
        ``type_name`` unchanged when it already contains a dot, when there
        is no parent stack, or when no enclosing message declares it.
        Otherwise the dotted path from the outermost message down to the
        innermost message that declares the type, followed by the name.
    """
    if "." in type_name or not parent_stack:
        return type_name

    # Search from the innermost enclosing message outwards.
    for depth in range(len(parent_stack), 0, -1):
        owner = parent_stack[depth - 1]
        if owner.get_nested_type(type_name) is None:
            continue
        path = [parent.name for parent in parent_stack[:depth]]
        path.append(type_name)
        return ".".join(path)
    return type_name
def get_default_value(self, field_type: FieldType, nullable: bool = False) -> str:
    """Return the source text of the default value for a field type.

    Returns:
        ``"None"`` for nullable fields and for every non-primitive type.
        Primitives use ``DEFAULT_VALUES``. A list of primitives whose kind
        is in ``NUMPY_DTYPE_MAP`` also defaults to ``None`` but carries a
        trailing comment that shows how to create an empty array.
    """
    no_default = "None"
    if nullable:
        return no_default
    if isinstance(field_type, PrimitiveType):
        return self.DEFAULT_VALUES.get(field_type.kind, no_default)
    if isinstance(field_type, NamedType):
        return no_default
    if isinstance(field_type, ListType):
        element = field_type.element_type
        # Use numpy empty array for numeric types
        if (
            isinstance(element, PrimitiveType)
            and element.kind in self.NUMPY_DTYPE_MAP
        ):
            dtype = self.NUMPY_DTYPE_MAP[element.kind]
            return f"None # Use np.array([], dtype={dtype}) to initialize"
    # Remaining lists, maps and unknown types have no usable default.
    return no_default
def collect_imports(
    self,
    field_type: FieldType,
    imports: Set[str],
    element_optional: bool = False,
):
    """Add the import lines that a field type requires to ``imports``.

    Args:
        field_type: Schema type to inspect.
        imports: Set of import lines, updated in place.
        element_optional: For lists, whether the elements are optional;
            optional elements rule out the ndarray representation.
    """
    if isinstance(field_type, PrimitiveType):
        kind = field_type.kind
        if kind in (PrimitiveKind.DATE, PrimitiveKind.TIMESTAMP):
            imports.add("import datetime")
        elif kind == PrimitiveKind.ANY:
            imports.add("from typing import Any")
        return

    if isinstance(field_type, ListType):
        element = field_type.element_type
        # Add numpy import for primitive arrays
        as_array = (
            isinstance(element, PrimitiveType)
            and element.kind in self.ARRAY_TYPE_HINTS
            and not element_optional
        )
        if as_array:
            imports.add("import numpy as np")
        else:
            # Plain lists need whatever their element type needs.
            self.collect_imports(element, imports)
        return

    if isinstance(field_type, MapType):
        for part in (field_type.key_type, field_type.value_type):
            self.collect_imports(part, imports)
def collect_field_imports(self, field: Field, imports: Set[str]):
    """Add the imports needed by one field to ``imports`` in place.

    The field's ``element_optional`` modifier is passed along because it
    decides whether a primitive list is rendered as an ndarray.
    """
    field_type = field.field_type
    self.collect_imports(field_type, imports, field.element_optional)
def generate_registration(self) -> List[str]:
    """Generate the source lines of the Fory registration function.

    The function is named ``register_<module name>_types`` and takes a
    ``pyfory.Fory`` instance. Its body registers top-level enums, then
    top-level unions, then messages together with their nested types; for
    a schema without any types the body is ``pass``.
    """
    lines = [f"def register_{self.get_module_name()}_types(fory: pyfory.Fory):"]
    schema = self.schema

    if not (schema.enums or schema.messages or schema.unions):
        lines.append(" pass")
        return lines

    # Each helper appends its registration statements to ``lines``.
    top_level = ""
    for enum in schema.enums:
        self.generate_enum_registration(lines, enum, top_level)
    for union in schema.unions:
        self.generate_union_registration(lines, union, top_level)
    for message in schema.messages:
        self.generate_message_registration(lines, message, top_level)
    return lines
def generate_enum_registration(
    self, lines: List[str], enum: Enum, parent_path: str
):
    """Append the registration statement for an enum to ``lines``.

    Args:
        lines: Body of the registration function, extended in place.
        enum: Enum definition; ``name`` and ``type_id`` are read.
        parent_path: Dotted path of the enclosing classes, or ``""`` for a
            top-level enum.

    An enum with a type id is registered by that id; otherwise it is
    registered by namespace (the schema package, or ``"default"``) and a
    type name equal to its dotted class reference.
    """
    # In Python, nested class references use Outer.Inner syntax
    class_ref = f"{parent_path}.{enum.name}" if parent_path else enum.name

    if enum.type_id is None:
        ns = self.package or "default"
        statement = (
            f' fory.register_type({class_ref}, namespace="{ns}", typename="{class_ref}")'
        )
    else:
        statement = f" fory.register_type({class_ref}, type_id={enum.type_id})"
    lines.append(statement)
def generate_message_registration(
    self, lines: List[str], message: Message, parent_path: str
):
    """Append the registration statements for a message and its nested types.

    Args:
        lines: Body of the registration function, extended in place.
        message: Message definition to register.
        parent_path: Dotted path of the enclosing classes, or ``""`` for a
            top-level message.

    The message itself is registered first (by type id when it has one,
    otherwise by namespace and dotted type name), followed by its nested
    enums, nested unions and nested messages.
    """
    # In Python, nested class references use Outer.Inner syntax
    class_ref = f"{parent_path}.{message.name}" if parent_path else message.name

    if message.type_id is None:
        ns = self.package or "default"
        statement = (
            f' fory.register_type({class_ref}, namespace="{ns}", typename="{class_ref}")'
        )
    else:
        statement = f" fory.register_type({class_ref}, type_id={message.type_id})"
    lines.append(statement)

    # Nested types are registered relative to this message's class path.
    for nested_enum in message.nested_enums:
        self.generate_enum_registration(lines, nested_enum, class_ref)
    for nested_union in message.nested_unions:
        self.generate_union_registration(lines, nested_union, class_ref)
    for nested_msg in message.nested_messages:
        self.generate_message_registration(lines, nested_msg, class_ref)
def generate_union_registration(
    self, lines: List[str], union: Union, parent_path: str
):
    """Append the registration statement for a union to ``lines``.

    Args:
        lines: Body of the registration function, extended in place.
        union: Union definition; ``name`` and ``type_id`` are read.
        parent_path: Dotted path of the enclosing classes, or ``""`` for a
            top-level union.

    The statement calls ``fory.register_union`` with an instance of the
    generated ``<Name>Serializer`` class. A union with a type id is
    registered by that id, otherwise by namespace and dotted type name.
    """
    class_ref = f"{parent_path}.{union.name}" if parent_path else union.name
    # The serializer class is generated right next to the union class.
    serializer_ref = f"{class_ref}Serializer"

    if union.type_id is None:
        ns = self.package or "default"
        statement = (
            f' fory.register_union({class_ref}, namespace="{ns}", typename="{class_ref}", serializer={serializer_ref}(fory))'
        )
    else:
        statement = (
            f" fory.register_union({class_ref}, type_id={union.type_id}, serializer={serializer_ref}(fory))"
        )
    lines.append(statement)