blob: 74637151528d44b8413fdb1668122bc366cf71c2 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""C++ code generator."""
from typing import Dict, List, Optional, Set
import typing
from fory_compiler.generators.base import BaseGenerator, GeneratedFile
from fory_compiler.ir.ast import (
Message,
Enum,
Union,
Field,
FieldType,
PrimitiveType,
NamedType,
ListType,
MapType,
)
from fory_compiler.ir.types import PrimitiveKind
class CppGenerator(BaseGenerator):
"""Generates C++ classes with FORY_STRUCT macros."""
# Target language key and the extension of the emitted header file.
language_name = "cpp"
file_extension = ".h"
# Mapping from FDL primitive types to C++ types
# Fixed-width, varint and tagged kinds of the same width share one C++ type;
# FLOAT16 maps to `float`, the same C++ type as FLOAT32.
PRIMITIVE_MAP = {
PrimitiveKind.BOOL: "bool",
PrimitiveKind.INT8: "int8_t",
PrimitiveKind.INT16: "int16_t",
PrimitiveKind.INT32: "int32_t",
PrimitiveKind.VARINT32: "int32_t",
PrimitiveKind.INT64: "int64_t",
PrimitiveKind.VARINT64: "int64_t",
PrimitiveKind.TAGGED_INT64: "int64_t",
PrimitiveKind.UINT8: "uint8_t",
PrimitiveKind.UINT16: "uint16_t",
PrimitiveKind.UINT32: "uint32_t",
PrimitiveKind.VAR_UINT32: "uint32_t",
PrimitiveKind.UINT64: "uint64_t",
PrimitiveKind.VAR_UINT64: "uint64_t",
PrimitiveKind.TAGGED_UINT64: "uint64_t",
PrimitiveKind.FLOAT16: "float",
PrimitiveKind.FLOAT32: "float",
PrimitiveKind.FLOAT64: "double",
PrimitiveKind.STRING: "std::string",
PrimitiveKind.BYTES: "std::vector<uint8_t>",
PrimitiveKind.DATE: "fory::serialization::Date",
PrimitiveKind.TIMESTAMP: "fory::serialization::Timestamp",
PrimitiveKind.ANY: "std::any",
}
# Kinds that is_numeric_field() reports as numeric. generate_field_accessors()
# emits by-value getters for these. Note that BOOL is part of the set.
NUMERIC_PRIMITIVES = {
PrimitiveKind.BOOL,
PrimitiveKind.INT8,
PrimitiveKind.INT16,
PrimitiveKind.INT32,
PrimitiveKind.VARINT32,
PrimitiveKind.INT64,
PrimitiveKind.VARINT64,
PrimitiveKind.TAGGED_INT64,
PrimitiveKind.UINT8,
PrimitiveKind.UINT16,
PrimitiveKind.UINT32,
PrimitiveKind.VAR_UINT32,
PrimitiveKind.UINT64,
PrimitiveKind.VAR_UINT64,
PrimitiveKind.TAGGED_UINT64,
PrimitiveKind.FLOAT16,
PrimitiveKind.FLOAT32,
PrimitiveKind.FLOAT64,
}
def generate(self) -> List[GeneratedFile]:
    """Produce the generated C++ output for the schema.

    Every type is emitted into one header, so the returned list holds
    exactly the file built by ``generate_header``.
    """
    header = self.generate_header()
    return [header]
def get_header_name(self) -> str:
    """Return the header base name (without extension).

    Dots in the package name become underscores; when there is no
    package the name falls back to ``"generated"``.
    """
    return self.package.replace(".", "_") if self.package else "generated"
def get_namespace(self) -> str:
    """Return the C++ namespace for the schema package.

    Package dots are rewritten to ``::``; an empty string is returned
    when the schema has no package.
    """
    return self.package.replace(".", "::") if self.package else ""
def get_namespaced_type_name(
    self,
    type_name: str,
    parent_stack: List[Message],
) -> str:
    """Return the type name prefixed with the schema namespace.

    The qualified name comes from ``get_qualified_type_name``; the
    namespace prefix is only added when ``get_namespace`` is non-empty.
    """
    qualified = self.get_qualified_type_name(type_name, parent_stack)
    ns = self.get_namespace()
    return f"{ns}::{qualified}" if ns else qualified
def get_field_config_type(
    self,
    type_name: str,
    parent_stack: List[Message],
) -> str:
    """Return the type spelling used as the FORY_FIELD_CONFIG target.

    This simply delegates to ``get_namespaced_type_name``.
    """
    config_type = self.get_namespaced_type_name(type_name, parent_stack)
    return config_type
def generate_header(self) -> GeneratedFile:
"""Generate a C++ header file with all types.

Output order: license comment, include guard, sorted ``#include`` lines,
the opening namespace (only when ``get_namespace`` is non-empty), forward
declarations, top-level enums, then top-level unions/messages in the order
returned by ``get_definition_order``. The union, field-config and enum
macro lines collected while emitting the types are written after the
definitions, followed by the registration code, the namespace close and
the closing ``#endif``.

Returns:
A ``GeneratedFile`` whose path is ``<header name>.h`` and whose content
is the emitted lines joined with newlines.
"""
lines = []
includes: Set[str] = set()
# Macro lines are accumulated here and emitted after all definitions.
enum_macros: List[str] = []
union_macros: List[str] = []
field_config_macros: List[str] = []
# Computed up front; raises ValueError when the order cannot be resolved.
definition_items = self.get_definition_order()
# Collect includes (including from nested types)
includes.add("<cstdint>")
includes.add("<map>")
includes.add("<memory>")
includes.add("<string>")
includes.add("<vector>")
includes.add("<utility>")
includes.add('"fory/serialization/fory.h"')
if self.schema_has_unions():
# <utility> and <memory> were already added above; the set deduplicates.
includes.add("<utility>")
includes.add("<variant>")
includes.add("<memory>")
includes.add("<typeindex>")
includes.add('"fory/serialization/union_serializer.h"')
for message in self.schema.messages:
self.collect_message_includes(message, includes)
for union in self.schema.unions:
self.collect_union_includes(union, includes)
# License header
lines.append("/*")
for line in self.get_license_header(" *").split("\n"):
lines.append(line)
lines.append(" */")
lines.append("")
# Header guard
guard_name = f"{self.get_header_name().upper()}_H_"
lines.append(f"#ifndef {guard_name}")
lines.append(f"#define {guard_name}")
lines.append("")
# Includes
# Sorted so the emitted include list is deterministic.
for inc in sorted(includes):
lines.append(f"#include {inc}")
lines.append("")
# Namespace
namespace = self.get_namespace()
if namespace:
lines.append(f"namespace {namespace} {{")
lines.append("")
# Forward declarations (top-level messages only)
self.generate_forward_declarations(lines)
if self.schema.messages:
lines.append("")
# Generate enums (top-level)
for enum in self.schema.enums:
lines.extend(self.generate_enum_definition(enum))
enum_macros.append(self.generate_enum_macro(enum, []))
lines.append("")
# Generate top-level unions/messages in dependency order
for kind, item in definition_items:
if kind == "union":
lines.extend(self.generate_union_definition(item, [], ""))
union_macros.extend(self.generate_union_macros(item, []))
lines.append("")
continue
# generate_message_definition appends the macros of the message and of
# its nested types to the three accumulator lists passed here.
lines.extend(
self.generate_message_definition(
item,
[],
enum_macros,
union_macros,
field_config_macros,
"",
)
)
lines.append("")
# Emit the collected macros: unions first, then field configs, then enums.
if union_macros:
lines.extend(union_macros)
lines.append("")
if field_config_macros:
lines.extend(field_config_macros)
lines.append("")
if enum_macros:
lines.extend(enum_macros)
lines.append("")
# Generate registration function (after FORY_STRUCT/FORY_ENUM)
lines.extend(self.generate_registration())
lines.append("")
if namespace:
lines.append(f"}} // namespace {namespace}")
lines.append("")
# End header guard
lines.append(f"#endif // {guard_name}")
lines.append("")
return GeneratedFile(
path=f"{self.get_header_name()}.h",
content="\n".join(lines),
)
def collect_message_includes(self, message: Message, includes: Set[str]):
    """Add the includes needed by ``message`` and everything nested in it.

    Each field is forwarded to ``collect_includes``; nested messages and
    nested unions are then visited recursively. ``includes`` is mutated
    in place.
    """
    for member in message.fields:
        weak = self.get_field_weak_ref(member)
        element_weak = self.get_element_weak_ref(member)
        # For maps the weak-ref flag is taken from the map value options.
        if isinstance(member.field_type, MapType):
            weak = self.get_map_value_weak_ref(member)
        self.collect_includes(
            member.field_type,
            member.optional,
            member.ref,
            includes,
            member.element_optional,
            member.element_ref,
            weak,
            element_weak,
        )
    for child_message in message.nested_messages:
        self.collect_message_includes(child_message, includes)
    for child_union in message.nested_unions:
        self.collect_union_includes(child_union, includes)
def collect_union_includes(self, union: Union, includes: Set[str]):
    """Add the includes required by every case of ``union`` to ``includes``.

    The argument positions mirror the call made in
    ``collect_message_includes``; for union cases the optional, ref and
    both weak-ref arguments are always False.
    """
    for case in union.fields:
        case_type = case.field_type
        self.collect_includes(
            case_type,
            False,
            False,
            includes,
            case.element_optional,
            case.element_ref,
            False,
            False,
        )
def generate_forward_declarations(self, lines: List[str]):
    """Append a ``class Name;`` line to ``lines`` for each top-level message."""
    lines.extend(f"class {message.name};" for message in self.schema.messages)
def get_definition_order(self) -> List:
    """Return top-level unions/messages in dependency order.

    The result is a list of ``("union", union)`` / ``("message", message)``
    tuples. Candidates are seeded in schema order (all unions, then all
    messages). An item is emitted only after every top-level item it
    depends on, as reported by ``collect_union_dependencies`` /
    ``collect_message_dependencies``. Whenever several items are ready,
    the one with the lowest seed index goes first, so the output is
    deterministic.

    Raises:
        ValueError: if the dependencies contain a cycle, so that not every
            item can be placed.
    """
    import heapq

    items: List = [("union", union) for union in self.schema.unions]
    items.extend(("message", message) for message in self.schema.messages)

    # On duplicate names the later item wins, as a plain dict assignment would.
    name_to_index = {item.name: idx for idx, (_, item) in enumerate(items)}

    dependencies: List[Set[int]] = [set() for _ in items]
    reverse_edges: List[Set[int]] = [set() for _ in items]
    for idx, (kind, item) in enumerate(items):
        deps: Set[str] = set()
        if kind == "union":
            self.collect_union_dependencies(item, [], deps)
        else:
            self.collect_message_dependencies(item, [], deps)
        for dep_name in deps:
            dep_idx = name_to_index.get(dep_name)
            # Ignore names that are not top-level items and self-references.
            if dep_idx is None or dep_idx == idx:
                continue
            dependencies[idx].add(dep_idx)
            reverse_edges[dep_idx].add(idx)

    in_degree = [len(found) for found in dependencies]
    # An ascending list is already a valid min-heap; popping the smallest
    # index each round matches the previous sort-then-pop(0) behaviour
    # without re-sorting the ready list on every iteration.
    ready = [idx for idx, degree in enumerate(in_degree) if degree == 0]
    ordered: List = []
    while ready:
        idx = heapq.heappop(ready)
        ordered.append(items[idx])
        for dependant in reverse_edges[idx]:
            in_degree[dependant] -= 1
            if in_degree[dependant] == 0:
                heapq.heappush(ready, dependant)

    if len(ordered) != len(items):
        raise ValueError("C++ generator cannot resolve type order for unions.")
    return ordered
def collect_message_dependencies(
    self, message: Message, parent_stack: List[Message], deps: Set[str]
) -> None:
    """Collect the top-level type names ``message`` depends on into ``deps``.

    Fields flagged ``ref`` or ``element_ref``, and map fields whose type
    has ``value_ref`` set, are skipped. Nested unions and nested messages
    are visited with ``message`` appended to the parent stack.
    """
    scope = parent_stack + [message]
    for member in message.fields:
        skipped = (
            member.ref
            or member.element_ref
            or (
                isinstance(member.field_type, MapType)
                and member.field_type.value_ref
            )
        )
        if not skipped:
            self.collect_type_dependencies(member.field_type, scope, deps)
    for child_union in message.nested_unions:
        self.collect_union_dependencies(child_union, scope, deps)
    for child_message in message.nested_messages:
        self.collect_message_dependencies(child_message, scope, deps)
def collect_union_dependencies(
    self, union: Union, parent_stack: List[Message], deps: Set[str]
) -> None:
    """Collect the top-level type names used by the cases of ``union``.

    Every case type is forwarded to ``collect_type_dependencies``; unlike
    message fields, no case is skipped.
    """
    for case in union.fields:
        case_type = case.field_type
        self.collect_type_dependencies(case_type, parent_stack, deps)
def collect_type_dependencies(
    self, field_type: FieldType, parent_stack: List[Message], deps: Set[str]
) -> None:
    """Add the top-level type names referenced by ``field_type`` to ``deps``.

    Primitive types contribute nothing. A named type contributes the first
    dotted component of its name unless ``is_nested_type_reference`` says it
    refers to a type nested in the current scope. List and map types are
    followed into their element, key and value types.
    """
    if isinstance(field_type, PrimitiveType):
        return
    if isinstance(field_type, NamedType):
        referenced = field_type.name
        if not self.is_nested_type_reference(referenced, parent_stack):
            # A dotted name depends on its outermost component.
            deps.add(referenced.split(".")[0])
    elif isinstance(field_type, ListType):
        self.collect_type_dependencies(field_type.element_type, parent_stack, deps)
    elif isinstance(field_type, MapType):
        self.collect_type_dependencies(field_type.key_type, parent_stack, deps)
        self.collect_type_dependencies(field_type.value_type, parent_stack, deps)
def is_nested_type_reference(
    self, type_name: str, parent_stack: List[Message]
) -> bool:
    """Tell whether ``type_name`` refers to a type inside the current scope.

    With an empty parent stack the answer is always False. A dotted name
    counts as nested when its first component equals the name of the
    outermost message on the stack. A plain name counts as nested when any
    message on the stack reports it through ``get_nested_type``.
    """
    if not parent_stack:
        return False
    outermost = parent_stack[0].name
    if "." in type_name:
        return type_name.split(".")[0] == outermost
    # Search from the innermost enclosing message outwards.
    return any(
        enclosing.get_nested_type(type_name) is not None
        for enclosing in reversed(parent_stack)
    )
def get_enum_value_names(self, enum: Enum) -> List[str]:
    """Return the enum's value names after ``strip_enum_prefix`` is applied."""
    return [
        self.strip_enum_prefix(enum.name, member.name) for member in enum.values
    ]
def generate_enum_definition(self, enum: Enum, indent: str = "") -> List[str]:
    """Return the lines of a C++ ``enum class`` with ``int32_t`` storage.

    Each value is emitted as ``<stripped name> = <value>,``; every line is
    prefixed with ``indent``.
    """
    opening = f"{indent}enum class {enum.name} : int32_t {{"
    entries = []
    for member in enum.values:
        short_name = self.strip_enum_prefix(enum.name, member.name)
        entries.append(f"{indent} {short_name} = {member.value},")
    closing = f"{indent}}};"
    return [opening, *entries, closing]
def generate_enum_macro(
    self,
    enum: Enum,
    parent_stack: List[Message],
) -> str:
    """Return the ``FORY_ENUM(<namespaced type>, <values...>);`` line."""
    joined_values = ", ".join(self.get_enum_value_names(enum))
    target = self.get_namespaced_type_name(enum.name, parent_stack)
    return f"FORY_ENUM({target}, {joined_values});"
def resolve_named_type(
    self, name: str, parent_stack: Optional[List[Message]]
) -> Optional[typing.Union[Message, Enum, Union]]:
    """Resolve a type name to its schema definition, or None.

    A dotted name is resolved from the top level, descending through
    ``get_nested_type`` one component at a time; the walk stops with None
    as soon as the current node is not a message. A plain name is looked up
    in the enclosing messages, innermost first, and then at the top level.
    """
    head, *rest = name.split(".")
    if rest:
        node = self.find_top_level_type(head)
        for segment in rest:
            if not isinstance(node, Message):
                return None
            node = node.get_nested_type(segment)
        return node
    if parent_stack:
        for enclosing in reversed(parent_stack):
            found = enclosing.get_nested_type(name)
            if found is not None:
                return found
    return self.find_top_level_type(name)
def find_top_level_type(
    self, name: str
) -> Optional[typing.Union[Message, Enum, Union]]:
    """Return the top-level definition called ``name``, or None.

    Enums are searched first, then unions, then messages; the first match
    wins.
    """
    schema = self.schema
    for group in (schema.enums, schema.unions, schema.messages):
        for candidate in group:
            if candidate.name == name:
                return candidate
    return None
def is_message_type(
    self, field_type: FieldType, parent_stack: Optional[List[Message]]
) -> bool:
    """Return True when ``field_type`` is a named type resolving to a message."""
    if isinstance(field_type, NamedType):
        target = self.resolve_named_type(field_type.name, parent_stack)
        return isinstance(target, Message)
    return False
def is_weak_ref(self, options: dict) -> bool:
    """Return True only when ``options["weak_ref"]`` is exactly ``True``.

    Truthy values other than the ``True`` singleton do not count.
    """
    flag = options.get("weak_ref")
    return flag is True
def get_field_weak_ref(self, field: Field) -> bool:
    """Return whether the field's own ref options request a weak reference."""
    ref_options = field.ref_options
    return self.is_weak_ref(ref_options)
def get_element_weak_ref(self, field: Field) -> bool:
    """Return whether the field's element ref options request a weak reference."""
    element_options = field.element_ref_options
    return self.is_weak_ref(element_options)
def get_map_value_weak_ref(self, field: Field) -> bool:
    """Return whether a map field's value ref options request a weak reference.

    Fields that are not maps always yield False.
    """
    map_type = field.field_type
    if not isinstance(map_type, MapType):
        return False
    return self.is_weak_ref(map_type.value_ref_options)
def is_union_type(
    self, field_type: FieldType, parent_stack: Optional[List[Message]]
) -> bool:
    """Return True when ``field_type`` is a named type resolving to a union."""
    if isinstance(field_type, NamedType):
        target = self.resolve_named_type(field_type.name, parent_stack)
        return isinstance(target, Union)
    return False
def is_enum_type(
    self, field_type: FieldType, parent_stack: Optional[List[Message]]
) -> bool:
    """Return True when ``field_type`` is a named type resolving to an enum."""
    if isinstance(field_type, NamedType):
        target = self.resolve_named_type(field_type.name, parent_stack)
        return isinstance(target, Enum)
    return False
def get_field_member_name(self, field: Field) -> str:
    """Return the C++ data-member name: snake_case field name plus ``_``."""
    snake = self.to_snake_case(field.name)
    return f"{snake}_"
def get_field_storage_type(
    self, field: Field, parent_stack: Optional[List[Message]]
) -> str:
    """Return the C++ type used for the field's private data member.

    A message-typed field that is neither ``ref`` nor weak is stored as
    ``std::unique_ptr<T>``. Every other field goes through
    ``generate_type``; for message fields held by ``ref``/weak reference
    the optional flag is forced to False.
    """
    declared = field.field_type
    weak = self.get_field_weak_ref(field)
    element_weak = self.get_element_weak_ref(field)
    # For maps the weak-ref flag is taken from the map value options.
    if isinstance(declared, MapType):
        weak = self.get_map_value_weak_ref(field)

    is_message = self.is_message_type(declared, parent_stack)
    if is_message and not (field.ref or weak):
        owned = self.resolve_nested_type_name(declared.name, parent_stack)
        return f"std::unique_ptr<{owned}>"

    # A message reaching this point is necessarily ref/weak, so only
    # non-message fields keep their declared optional flag.
    optional = False if is_message else field.optional
    return self.generate_type(
        declared,
        optional,
        field.ref,
        field.element_optional,
        field.element_ref,
        weak,
        element_weak,
        parent_stack,
    )
def get_field_value_type(
    self, field: Field, parent_stack: Optional[List[Message]]
) -> str:
    """Return the C++ type exposed by the field's accessors.

    For a message-typed field that is neither ``ref`` nor weak this is the
    bare message type name. Otherwise it is the ``generate_type`` result
    with the optional flag forced to False.
    """
    declared = field.field_type
    weak = self.get_field_weak_ref(field)
    element_weak = self.get_element_weak_ref(field)
    # For maps the weak-ref flag is taken from the map value options.
    if isinstance(declared, MapType):
        weak = self.get_map_value_weak_ref(field)

    owned_message = self.is_message_type(declared, parent_stack) and not (
        field.ref or weak
    )
    if owned_message:
        return self.resolve_nested_type_name(declared.name, parent_stack)
    return self.generate_type(
        declared,
        False,
        field.ref,
        field.element_optional,
        field.element_ref,
        weak,
        element_weak,
        parent_stack,
    )
def get_field_eq_expression(
    self, field: Field, parent_stack: Optional[List[Message]]
) -> str:
    """Return the C++ expression comparing this field with ``other``'s.

    - ``ANY`` fields: true when both are empty or both report the same
      ``type()``; the held values themselves are not compared.
    - message fields with a weak field ref: both sides are upgraded and the
      results compared.
    - message fields with ``ref``: the members are compared directly.
    - other message fields: the pointees are compared when both members are
      set, otherwise the members themselves.
    - everything else: plain ``==``.
    """
    lhs = self.get_field_member_name(field)
    rhs = f"other.{lhs}"
    declared = field.field_type

    if isinstance(declared, PrimitiveType) and declared.kind == PrimitiveKind.ANY:
        return (
            f"((!{lhs}.has_value() && !{rhs}.has_value()) || "
            f"({lhs}.type() == {rhs}.type()))"
        )

    direct = f"{lhs} == {rhs}"
    if not self.is_message_type(declared, parent_stack):
        return direct
    if self.get_field_weak_ref(field):
        return (
            f"(([&]() {{ auto left = {lhs}.upgrade(); "
            f"auto right = {rhs}.upgrade(); "
            f"return left == right; }})())"
        )
    if field.ref:
        return direct
    return (
        f"(({lhs} && {rhs}) ? "
        f"(*{lhs} == *{rhs}) : "
        f"({lhs} == {rhs}))"
    )
def is_numeric_field(self, field: Field) -> bool:
    """Return True when the field is a primitive listed in NUMERIC_PRIMITIVES."""
    declared = field.field_type
    if isinstance(declared, PrimitiveType):
        return declared.kind in self.NUMERIC_PRIMITIVES
    return False
def is_string_field(self, field: Field) -> bool:
    """Return True when the field is the primitive STRING kind."""
    declared = field.field_type
    if not isinstance(declared, PrimitiveType):
        return False
    return declared.kind == PrimitiveKind.STRING
def is_bytes_field(self, field: Field) -> bool:
    """Return True when the field is the primitive BYTES kind."""
    declared = field.field_type
    if not isinstance(declared, PrimitiveType):
        return False
    return declared.kind == PrimitiveKind.BYTES
def generate_field_accessors(
self, field: Field, parent_stack: Optional[List[Message]], indent: str
) -> List[str]:
"""Return the C++ accessor method lines for one message field.

Three shapes are produced:

* Message field held by ``ref`` or weak ref: ``has_``, a const getter
returning the member, ``mutable_`` returning its address, ``set_`` and
``clear_`` (which assigns a value-initialised ``value_type``).
* Any other message field: ``has_`` (null check), a const getter that
dereferences the member, ``mutable_`` that creates the object on first
use via ``std::make_unique``, and ``clear_``. No setter is emitted.
* Everything else: ``has_``/``clear_`` only for optional fields; the getter
returns by value for numeric and enum fields and by const reference
otherwise; ``mutable_`` for strings, lists, maps, bytes and unions;
``set_`` unless the field is a list, map, bytes or union.

Every emitted line is prefixed with ``indent``.
"""
lines: List[str] = []
field_name = self.to_snake_case(field.name)
member_name = self.get_field_member_name(field)
value_type = self.get_field_value_type(field, parent_stack)
weak_ref = self.get_field_weak_ref(field)
is_union = self.is_union_type(field.field_type, parent_stack)
is_enum = self.is_enum_type(field.field_type, parent_stack)
is_collection = isinstance(field.field_type, (ListType, MapType))
is_bytes = self.is_bytes_field(field)
is_string = self.is_string_field(field)
# Which optional accessors the generic (third) shape emits.
needs_mutable = is_string or is_collection or is_bytes or is_union
no_setter = is_collection or is_bytes or is_union
value_getter = self.is_numeric_field(field) or is_enum
# Shape 1: message held through a ref / weak ref wrapper.
if self.is_message_type(field.field_type, parent_stack) and (
field.ref or weak_ref
):
lines.append(f"{indent}bool has_{field_name}() const {{")
if weak_ref:
lines.append(f"{indent} return !{member_name}.expired();")
else:
lines.append(f"{indent} return {member_name} != nullptr;")
lines.append(f"{indent}}}")
lines.append("")
lines.append(f"{indent}const {value_type}& {field_name}() const {{")
lines.append(f"{indent} return {member_name};")
lines.append(f"{indent}}}")
lines.append("")
lines.append(f"{indent}{value_type}* mutable_{field_name}() {{")
lines.append(f"{indent} return &{member_name};")
lines.append(f"{indent}}}")
lines.append("")
lines.append(f"{indent}void set_{field_name}({value_type} value) {{")
lines.append(f"{indent} {member_name} = std::move(value);")
lines.append(f"{indent}}}")
lines.append("")
lines.append(f"{indent}void clear_{field_name}() {{")
lines.append(f"{indent} {member_name} = {value_type}();")
lines.append(f"{indent}}}")
return lines
# Shape 2: message stored as std::unique_ptr (see get_field_storage_type).
if self.is_message_type(field.field_type, parent_stack):
lines.append(f"{indent}bool has_{field_name}() const {{")
lines.append(f"{indent} return {member_name} != nullptr;")
lines.append(f"{indent}}}")
lines.append("")
# NOTE(review): the emitted getter dereferences the member without a null
# check; presumably callers of the generated C++ check has_<field>() first.
lines.append(f"{indent}const {value_type}& {field_name}() const {{")
lines.append(f"{indent} return *{member_name};")
lines.append(f"{indent}}}")
lines.append("")
lines.append(f"{indent}{value_type}* mutable_{field_name}() {{")
lines.append(f"{indent} if (!{member_name}) {{")
lines.append(
f"{indent} {member_name} = std::make_unique<{value_type}>();"
)
lines.append(f"{indent} }}")
lines.append(f"{indent} return {member_name}.get();")
lines.append(f"{indent}}}")
lines.append("")
lines.append(f"{indent}void clear_{field_name}() {{")
lines.append(f"{indent} {member_name}.reset();")
lines.append(f"{indent}}}")
return lines
# Shape 3: primitives, enums, strings, bytes, collections and unions.
if field.optional:
lines.append(f"{indent}bool has_{field_name}() const {{")
lines.append(f"{indent} return {member_name}.has_value();")
lines.append(f"{indent}}}")
lines.append("")
if value_getter:
lines.append(f"{indent}{value_type} {field_name}() const {{")
else:
lines.append(f"{indent}const {value_type}& {field_name}() const {{")
# NOTE(review): for optional fields the emitted getter dereferences the
# member without checking has_value() — confirm this is intended.
if field.optional:
lines.append(f"{indent} return *{member_name};")
else:
lines.append(f"{indent} return {member_name};")
lines.append(f"{indent}}}")
if needs_mutable:
lines.append("")
lines.append(f"{indent}{value_type}* mutable_{field_name}() {{")
if field.optional:
# The emitted code emplaces a value first when the member is empty.
lines.append(f"{indent} if (!{member_name}) {{")
lines.append(f"{indent} {member_name}.emplace();")
lines.append(f"{indent} }}")
lines.append(f"{indent} return &{member_name}.value();")
else:
lines.append(f"{indent} return &{member_name};")
lines.append(f"{indent}}}")
if not no_setter:
lines.append("")
if is_string:
# String setters are variadic templates that forward their arguments to
# the value's constructor (or to emplace() for optional fields).
lines.append(f"{indent}template <class Arg, class... Args>")
lines.append(
f"{indent}void set_{field_name}(Arg&& arg, Args&&... args) {{"
)
if field.optional:
lines.append(
f"{indent} {member_name}.emplace(std::forward<Arg>(arg), std::forward<Args>(args)...);"
)
else:
lines.append(
f"{indent} {member_name} = {value_type}(std::forward<Arg>(arg), std::forward<Args>(args)...);"
)
lines.append(f"{indent}}}")
else:
lines.append(f"{indent}void set_{field_name}({value_type} value) {{")
lines.append(f"{indent} {member_name} = std::move(value);")
lines.append(f"{indent}}}")
if field.optional:
lines.append("")
lines.append(f"{indent}void clear_{field_name}() {{")
lines.append(f"{indent} {member_name}.reset();")
lines.append(f"{indent}}}")
return lines
def generate_message_definition(
self,
message: Message,
parent_stack: List[Message],
enum_macros: List[str],
union_macros: List[str],
field_config_macros: List[str],
indent: str,
) -> List[str]:
"""Generate a C++ class definition with nested types.

The class body contains, in order: nested enums, nested messages
(generated recursively), nested unions, field accessors, ``operator==``,
the private data members and a ``FORY_STRUCT`` line.

Args:
message: The message to emit.
parent_stack: Enclosing messages, outermost first.
enum_macros: Output list; the FORY_ENUM line of each nested enum is
appended to it.
union_macros: Output list; the macro lines of each nested union are
appended to it.
field_config_macros: Output list; the field-config macro of this
message is appended to it when the message has fields.
indent: Prefix placed before the class's own lines.

Returns:
The lines of the class definition.
"""
lines: List[str] = []
class_name = message.name
# Nested types are generated with this message added to the parent stack.
lineage = parent_stack + [message]
# NOTE(review): body_indent and field_indent are the same string as written
# here — confirm whether field_indent was meant to be one level deeper.
body_indent = f"{indent} "
field_indent = f"{indent} "
lines.append(f"{indent}class {class_name} final {{")
lines.append(f"{body_indent}public:")
if message.fields:
lines.append("")
for nested_enum in message.nested_enums:
lines.extend(self.generate_enum_definition(nested_enum, body_indent))
lines.append("")
enum_macros.append(self.generate_enum_macro(nested_enum, lineage))
for nested_msg in message.nested_messages:
lines.extend(
self.generate_message_definition(
nested_msg,
lineage,
enum_macros,
union_macros,
field_config_macros,
body_indent,
)
)
lines.append("")
for nested_union in message.nested_unions:
lines.extend(
self.generate_union_definition(
nested_union,
lineage,
body_indent,
)
)
union_macros.extend(self.generate_union_macros(nested_union, lineage))
lines.append("")
if message.fields:
for index, field in enumerate(message.fields):
lines.extend(self.generate_field_accessors(field, lineage, body_indent))
# Blank separator between consecutive fields' accessors, not after the last.
if index + 1 < len(message.fields):
lines.append("")
lines.append("")
lines.append(
f"{body_indent}bool operator==(const {class_name}& other) const {{"
)
if message.fields:
# Equality is the conjunction of the per-field comparison expressions.
conditions = [
self.get_field_eq_expression(field, lineage) for field in message.fields
]
lines.append(f"{body_indent} return {' && '.join(conditions)};")
else:
lines.append(f"{body_indent} return true;")
lines.append(f"{body_indent}}}")
# Resolved against parent_stack (not lineage): this names the message itself.
struct_type_name = self.get_qualified_type_name(message.name, parent_stack)
if message.fields:
lines.append("")
lines.append(f"{body_indent}private:")
for field in message.fields:
field_type = self.get_field_storage_type(field, lineage)
member_name = self.get_field_member_name(field)
lines.append(f"{field_indent}{field_type} {member_name};")
lines.append("")
# Access is switched back to public before the FORY_STRUCT line.
lines.append(f"{body_indent}public:")
field_members = ", ".join(
self.get_field_member_name(f) for f in message.fields
)
field_config_type_name = self.get_field_config_type(
message.name, parent_stack
)
field_config_macros.append(
self.generate_field_config_macro(message, field_config_type_name)
)
lines.append(
f"{body_indent}FORY_STRUCT({struct_type_name}, {field_members});"
)
else:
lines.append(f"{body_indent}FORY_STRUCT({struct_type_name});")
lines.append(f"{indent}}};")
return lines
def generate_union_definition(
self,
union: Union,
parent_stack: List[Message],
indent: str,
) -> List[str]:
"""Generate a C++ union class definition.

The class wraps a ``std::variant`` over the case types and exposes:
a ``<Name>Case`` enum numbered by the case field numbers, a defaulted
constructor, one static factory per case, ``<name>_case()`` /
``<name>_case_id()`` / ``fory_case_id()``, per-case ``is_``/``as_``/getter
methods, ``visit`` overloads and ``operator==``.

NOTE(review): the emitted code selects alternatives by type
(``std::holds_alternative<T>``, ``std::get<T>``, ``std::in_place_type<T>``),
which requires each case to map to a distinct C++ type — TODO confirm this
is validated upstream. A union without fields would also produce
``std::variant<>``; presumably such unions are rejected earlier.

Args:
union: The union to emit.
parent_stack: Enclosing messages, outermost first.
indent: Prefix placed before the class's own lines.

Returns:
The lines of the class definition.
"""
lines: List[str] = []
class_name = union.name
body_indent = f"{indent} "
case_enum = f"{class_name}Case"
# One C++ type per case, in field order; also the variant alternatives.
case_types = [
self.get_union_case_type(field, parent_stack) for field in union.fields
]
variant_type = f"std::variant<{', '.join(case_types)}>"
lines.append(f"{indent}class {class_name} final {{")
lines.append(f"{body_indent}public:")
# Case enum: enumerator values are the schema field numbers.
lines.append(f"{body_indent} enum class {case_enum} : uint32_t {{")
for field in union.fields:
case_name = self.to_upper_snake_case(field.name)
lines.append(f"{body_indent} {case_name} = {field.number},")
lines.append(f"{body_indent} }};")
lines.append("")
lines.append(f"{body_indent} {class_name}() = default;")
lines.append("")
# Static factories, one per case, named after the snake_case field name.
for field, case_type in zip(union.fields, case_types):
case_name = self.to_snake_case(field.name)
lines.append(
f"{body_indent} static {class_name} {case_name}({case_type} v) {{"
)
lines.append(
f"{body_indent} return {class_name}(std::in_place_type<{case_type}>, std::move(v));"
)
lines.append(f"{body_indent} }}")
lines.append("")
lines.append(
f"{body_indent} {case_enum} {self.to_snake_case(class_name)}_case() const noexcept {{"
)
# Every case but the last gets a holds_alternative test; the last case is
# returned unconditionally as the fall-through.
for i, (field, case_type) in enumerate(zip(union.fields, case_types)):
case_name = self.to_upper_snake_case(field.name)
if i < len(case_types) - 1:
lines.append(
f"{body_indent} if (std::holds_alternative<{case_type}>(value_)) return {case_enum}::{case_name};"
)
else:
lines.append(f"{body_indent} return {case_enum}::{case_name};")
lines.append(f"{body_indent} }}")
lines.append("")
lines.append(
f"{body_indent} uint32_t {self.to_snake_case(class_name)}_case_id() const noexcept {{"
)
lines.append(
f"{body_indent} return static_cast<uint32_t>({self.to_snake_case(class_name)}_case());"
)
lines.append(f"{body_indent} }}")
lines.append("")
# fory_case_id() forwards to the union-specific <name>_case_id().
lines.append(f"{body_indent} uint32_t fory_case_id() const noexcept {{")
lines.append(
f"{body_indent} return {self.to_snake_case(class_name)}_case_id();"
)
lines.append(f"{body_indent} }}")
lines.append("")
# Per-case accessors: is_<case>, as_<case> (pointer, const and non-const)
# and <case>() (reference, const and non-const).
for field, case_type in zip(union.fields, case_types):
case_snake = self.to_snake_case(field.name)
lines.append(f"{body_indent} bool is_{case_snake}() const noexcept {{")
lines.append(
f"{body_indent} return std::holds_alternative<{case_type}>(value_);"
)
lines.append(f"{body_indent} }}")
lines.append("")
lines.append(
f"{body_indent} const {case_type}* as_{case_snake}() const noexcept {{"
)
lines.append(f"{body_indent} return std::get_if<{case_type}>(&value_);")
lines.append(f"{body_indent} }}")
lines.append("")
lines.append(f"{body_indent} {case_type}* as_{case_snake}() noexcept {{")
lines.append(f"{body_indent} return std::get_if<{case_type}>(&value_);")
lines.append(f"{body_indent} }}")
lines.append("")
lines.append(f"{body_indent} const {case_type}& {case_snake}() const {{")
lines.append(f"{body_indent} return std::get<{case_type}>(value_);")
lines.append(f"{body_indent} }}")
lines.append("")
lines.append(f"{body_indent} {case_type}& {case_snake}() {{")
lines.append(f"{body_indent} return std::get<{case_type}>(value_);")
lines.append(f"{body_indent} }}")
lines.append("")
# visit(): const and non-const overloads forwarding to std::visit.
lines.append(f"{body_indent} template <class Visitor>")
lines.append(f"{body_indent} decltype(auto) visit(Visitor&& vis) const {{")
lines.append(
f"{body_indent} return std::visit(std::forward<Visitor>(vis), value_);"
)
lines.append(f"{body_indent} }}")
lines.append("")
lines.append(f"{body_indent} template <class Visitor>")
lines.append(f"{body_indent} decltype(auto) visit(Visitor&& vis) {{")
lines.append(
f"{body_indent} return std::visit(std::forward<Visitor>(vis), value_);"
)
lines.append(f"{body_indent} }}")
lines.append("")
lines.append(
f"{body_indent} bool operator==(const {class_name}& other) const {{"
)
lines.append(f"{body_indent} return value_ == other.value_;")
lines.append(f"{body_indent} }}")
lines.append("")
# Private section: the variant storage and the constructor used by the
# static factories above.
lines.append(f"{body_indent}private:")
lines.append(f"{body_indent} {variant_type} value_;")
lines.append("")
lines.append(f"{body_indent} template <class T, class... Args>")
lines.append(
f"{body_indent} explicit {class_name}(std::in_place_type_t<T> tag, Args&&... args)"
)
lines.append(
f"{body_indent} : value_(tag, std::forward<Args>(args)...) {{}}"
)
lines.append(f"{indent}}};")
return lines
def generate_union_macros(
    self,
    union: Union,
    parent_stack: List[Message],
) -> List[str]:
    """Generate FORY_UNION metadata macros for a union.

    A union without fields yields no lines. With at most 16 cases a single
    multi-line ``FORY_UNION(...)`` invocation is produced, one
    ``(type, ctor, meta)`` tuple per case. Larger unions produce one
    ``FORY_UNION_IDS`` line followed by one ``FORY_UNION_CASE`` line per
    case.
    """
    cases = union.fields
    if not cases:
        return []

    compact = len(cases) <= 16
    union_type = self.get_namespaced_type_name(union.name, parent_stack)
    if compact:
        macro_lines: List[str] = [f"FORY_UNION({union_type},"]
    else:
        case_ids = ", ".join(str(field.number) for field in cases)
        macro_lines = [f"FORY_UNION_IDS({union_type}, {case_ids});"]

    last_position = len(cases) - 1
    for position, field in enumerate(cases):
        case_type = self.generate_namespaced_type(
            field.field_type,
            False,
            field.ref,
            field.element_optional,
            field.element_ref,
            False,
            False,
            parent_stack,
        )
        case_ctor = self.to_snake_case(field.name)
        meta = self.get_union_field_meta(field)
        if compact:
            # Tuples are comma separated; the last one has no trailing comma.
            suffix = "," if position < last_position else ""
            macro_lines.append(f" ({case_type}, {case_ctor}, {meta}){suffix}")
        else:
            macro_lines.append(
                f"FORY_UNION_CASE({union_type}, {field.number}, {case_type}, {union_type}::{case_ctor}, {meta});"
            )

    if compact:
        macro_lines.append(");")
    return macro_lines
def get_union_case_type(self, field: Field, parent_stack: List[Message]) -> str:
    """Return the C++ type for a union case.

    The case is passed to ``generate_type`` with the optional flag and both
    weak-ref flags set to False.
    """
    declared = field.field_type
    return self.generate_type(
        declared,
        False,
        field.ref,
        field.element_optional,
        field.element_ref,
        False,
        False,
        parent_stack,
    )
def schema_has_unions(self) -> bool:
    """Return True when the schema defines a union at any nesting level.

    Top-level unions are checked first; otherwise each top-level message
    is inspected with ``message_has_unions``.
    """
    if self.schema.unions:
        return True
    return any(self.message_has_unions(message) for message in self.schema.messages)
def message_has_unions(self, message: Message) -> bool:
    """Return True when ``message`` or any message nested in it declares a union."""
    if message.nested_unions:
        return True
    return any(self.message_has_unions(child) for child in message.nested_messages)
def collect_union_serializers(
    self,
    message: Message,
    parent_stack: List[Message],
    union_serializers: List[str],
) -> None:
    """Append serializer specializations for every union nested in ``message``.

    Unions declared directly in ``message`` come first, followed by those
    of its nested messages (depth first). ``union_serializers`` is
    extended in place.
    """
    scope = parent_stack + [message]
    for child_union in message.nested_unions:
        serializer_lines = self.generate_union_serializer(child_union, scope)
        union_serializers.extend(serializer_lines)
    for child_message in message.nested_messages:
        self.collect_union_serializers(child_message, scope, union_serializers)
def generate_union_serializer(
    self,
    union: Union,
    parent_stack: List[Message],
) -> List[str]:
    """Generate a C++ union serializer specialization.

    Builds the source lines of a ``template <> struct Serializer<T>``
    specialization for ``union``. The emitted struct contains, in order:
    ``write_type_info``, ``read_type_info``, ``write``, ``write_data``,
    ``write_data_generic``, ``read``, ``read_data``, ``read_with_type_info``
    and a private ``default_value`` helper.

    Args:
        union: The union definition to render.
        parent_stack: Messages enclosing the union, outermost first; used
            to qualify the union name and its case types.

    Returns:
        The generated lines, ending with a blank line. An empty list is
        returned when the union has no cases.
    """
    lines: List[str] = []
    qualified_name = self.get_namespaced_type_name(union.name, parent_stack)
    # Name of the accessor the emitted write_data() calls on the union object.
    case_id_method = f"{self.to_snake_case(union.name)}_case_id"
    # One namespaced C++ type per case, in the same order as union.fields.
    case_types = [
        self.generate_namespaced_type(
            field.field_type,
            False,
            field.ref,
            field.element_optional,
            field.element_ref,
            False,
            False,
            parent_stack,
        )
        for field in union.fields
    ]
    # Nothing to emit for a union without cases (default_value() needs one).
    if not union.fields:
        return lines
    # The first case supplies the value returned by default_value() on errors.
    default_field = union.fields[0]
    default_ctor = self.to_snake_case(default_field.name)
    default_type = case_types[0]
    lines.append("template <>")
    lines.append(f"struct Serializer<{qualified_name}> {{")
    lines.append(" static constexpr TypeId type_id = TypeId::UNION;")
    lines.append("")
    # write_type_info(): forwards to ctx.write_any_typeinfo and records a failure.
    lines.append(" static inline void write_type_info(WriteContext &ctx) {")
    lines.append(
        f" auto result = ctx.write_any_typeinfo(static_cast<uint32_t>(TypeId::TYPED_UNION), std::type_index(typeid({qualified_name})));"
    )
    lines.append(" if (FORY_PREDICT_FALSE(!result.ok())) {")
    lines.append(" ctx.set_error(std::move(result).error());")
    lines.append(" }")
    lines.append(" }")
    lines.append("")
    # read_type_info(): compares the type id read from the stream with the
    # locally resolved one and sets a type_mismatch error when they differ.
    lines.append(" static inline void read_type_info(ReadContext &ctx) {")
    lines.append(
        f" auto type_info_res = ctx.type_resolver().template get_type_info<{qualified_name}>();"
    )
    lines.append(" if (FORY_PREDICT_FALSE(!type_info_res.ok())) {")
    lines.append(" ctx.set_error(std::move(type_info_res).error());")
    lines.append(" return;")
    lines.append(" }")
    lines.append(" const TypeInfo *expected = type_info_res.value();")
    lines.append(" const TypeInfo *remote = ctx.read_any_typeinfo(ctx.error());")
    lines.append(" if (FORY_PREDICT_FALSE(ctx.has_error())) {")
    lines.append(" return;")
    lines.append(" }")
    lines.append(" if (!remote || remote->type_id != expected->type_id) {")
    lines.append(
        " ctx.set_error(Error::type_mismatch(remote ? remote->type_id : 0u, expected->type_id));"
    )
    lines.append(" }")
    lines.append(" }")
    lines.append("")
    # write(): optional ref flag, optional type info, then the payload.
    lines.append(
        f" static inline void write(const {qualified_name} &obj, WriteContext &ctx,"
    )
    lines.append(" RefMode ref_mode, bool write_type,")
    lines.append(" bool has_generics = false) {")
    lines.append(" (void)has_generics;")
    lines.append(" if (ref_mode == RefMode::Tracking && ctx.track_ref()) {")
    lines.append(" ctx.write_int8(REF_VALUE_FLAG);")
    lines.append(" ctx.ref_writer().reserve_ref_id();")
    lines.append(" } else if (ref_mode != RefMode::None) {")
    lines.append(" ctx.write_int8(NOT_NULL_VALUE_FLAG);")
    lines.append(" }")
    lines.append(" if (write_type) {")
    lines.append(" write_type_info(ctx);")
    lines.append(" if (FORY_PREDICT_FALSE(ctx.has_error())) {")
    lines.append(" return;")
    lines.append(" }")
    lines.append(" }")
    lines.append(" write_data(obj, ctx);")
    lines.append(" }")
    lines.append("")
    # write_data(): case id as varuint32, then the active alternative written
    # through its own Serializer with RefMode::Tracking and type info enabled.
    lines.append(
        f" static inline void write_data(const {qualified_name} &obj, WriteContext &ctx) {{"
    )
    lines.append(f" ctx.write_varuint32(obj.{case_id_method}());")
    lines.append(" obj.visit([&](const auto &value) {")
    lines.append(" using Alt = std::decay_t<decltype(value)>;")
    lines.append(
        " Serializer<Alt>::write(value, ctx, RefMode::Tracking, true);"
    )
    lines.append(" });")
    lines.append(" }")
    lines.append("")
    # write_data_generic(): ignores has_generics and delegates to write_data().
    lines.append(
        f" static inline void write_data_generic(const {qualified_name} &obj, WriteContext &ctx,"
    )
    lines.append(" bool has_generics) {")
    lines.append(" (void)has_generics;")
    lines.append(" write_data(obj, ctx);")
    lines.append(" }")
    lines.append("")
    # read(): validates the ref flag (NULL_FLAG, REF_FLAG and unknown flags
    # are reported as errors), optionally checks type info, then reads data.
    # Every error path returns default_value().
    lines.append(
        f" static inline {qualified_name} read(ReadContext &ctx, RefMode ref_mode,"
    )
    lines.append(" bool read_type) {")
    lines.append(" int8_t ref_flag = NOT_NULL_VALUE_FLAG;")
    lines.append(" if (ref_mode != RefMode::None) {")
    lines.append(" ref_flag = ctx.read_int8(ctx.error());")
    lines.append(" if (FORY_PREDICT_FALSE(ctx.has_error())) {")
    lines.append(" return default_value();")
    lines.append(" }")
    lines.append(" }")
    lines.append(" if (ref_flag == NULL_FLAG) {")
    lines.append(
        ' ctx.set_error(Error::invalid_data("Null value encountered for union"));'
    )
    lines.append(" return default_value();")
    lines.append(" }")
    lines.append(" if (ref_flag == REF_FLAG) {")
    lines.append(
        ' ctx.set_error(Error::invalid_ref("Unexpected reference flag for union"));'
    )
    lines.append(" return default_value();")
    lines.append(" }")
    lines.append(
        " if (ref_flag != NOT_NULL_VALUE_FLAG && ref_flag != REF_VALUE_FLAG) {"
    )
    lines.append(
        ' ctx.set_error(Error::invalid_ref("Unknown ref flag for union"));'
    )
    lines.append(" return default_value();")
    lines.append(" }")
    lines.append(" if (ctx.track_ref() && ref_flag == REF_VALUE_FLAG) {")
    lines.append(" ctx.ref_reader().reserve_ref_id();")
    lines.append(" }")
    lines.append(" if (read_type) {")
    lines.append(" read_type_info(ctx);")
    lines.append(" if (FORY_PREDICT_FALSE(ctx.has_error())) {")
    lines.append(" return default_value();")
    lines.append(" }")
    lines.append(" }")
    lines.append(" return read_data(ctx);")
    lines.append(" }")
    lines.append("")
    # read_data(): reads the case id and switches over the known cases.
    lines.append(f" static inline {qualified_name} read_data(ReadContext &ctx) {{")
    lines.append(" uint32_t case_id = ctx.read_varuint32(ctx.error());")
    lines.append(" if (FORY_PREDICT_FALSE(ctx.has_error())) {")
    lines.append(" return default_value();")
    lines.append(" }")
    lines.append(" switch (case_id) {")
    # One `case <number>:` arm per union case; each arm reads the payload via
    # the case type's Serializer and wraps it with the case's factory function.
    for field, case_type in zip(union.fields, case_types):
        case_ctor = self.to_snake_case(field.name)
        lines.append(f" case {field.number}: {{")
        lines.append(
            f" auto value = Serializer<{case_type}>::read(ctx, RefMode::Tracking, true);"
        )
        lines.append(" if (FORY_PREDICT_FALSE(ctx.has_error())) {")
        lines.append(" return default_value();")
        lines.append(" }")
        lines.append(
            f" return {qualified_name}::{case_ctor}(std::move(value));"
        )
        lines.append(" }")
    # default arm (unknown case id): reads the ref flag and type info, calls
    # skip_field_value for the payload, and always ends by setting an
    # "Unknown union case id" (or more specific) error.
    lines.append(" default: {")
    lines.append(" int8_t ref_flag = ctx.read_int8(ctx.error());")
    lines.append(" if (FORY_PREDICT_FALSE(ctx.has_error())) {")
    lines.append(" return default_value();")
    lines.append(" }")
    lines.append(" if (ref_flag == NULL_FLAG) {")
    lines.append(
        ' ctx.set_error(Error::invalid_data("Unknown union case id"));'
    )
    lines.append(" return default_value();")
    lines.append(" }")
    lines.append(" if (ref_flag == REF_FLAG) {")
    lines.append(" (void)ctx.read_varuint32(ctx.error());")
    lines.append(
        ' ctx.set_error(Error::invalid_data("Unknown union case id"));'
    )
    lines.append(" return default_value();")
    lines.append(" }")
    lines.append(
        " if (ref_flag != NOT_NULL_VALUE_FLAG && ref_flag != REF_VALUE_FLAG) {"
    )
    lines.append(
        ' ctx.set_error(Error::invalid_data("Unknown reference flag in union value"));'
    )
    lines.append(" return default_value();")
    lines.append(" }")
    lines.append(
        " const TypeInfo *type_info = ctx.read_any_typeinfo(ctx.error());"
    )
    lines.append(" if (FORY_PREDICT_FALSE(ctx.has_error())) {")
    lines.append(" return default_value();")
    lines.append(" }")
    lines.append(" if (!type_info) {")
    lines.append(
        ' ctx.set_error(Error::type_error("TypeInfo not found for union skip"));'
    )
    lines.append(" return default_value();")
    lines.append(" }")
    lines.append(" FieldType field_type;")
    lines.append(" field_type.type_id = type_info->type_id;")
    lines.append(" field_type.nullable = false;")
    lines.append(" skip_field_value(ctx, field_type, RefMode::None);")
    lines.append(
        ' ctx.set_error(Error::invalid_data("Unknown union case id"));'
    )
    lines.append(" return default_value();")
    lines.append(" }")
    lines.append(" }")
    lines.append(" }")
    lines.append("")
    # read_with_type_info(): ignores the supplied TypeInfo and calls read()
    # with read_type = false.
    lines.append(
        f" static inline {qualified_name} read_with_type_info(ReadContext &ctx, RefMode ref_mode,"
    )
    lines.append(" const TypeInfo &type_info) {")
    lines.append(" (void)type_info;")
    lines.append(" return read(ctx, ref_mode, false);")
    lines.append(" }")
    lines.append("")
    # default_value(): first case built from a value-initialized payload.
    lines.append("private:")
    lines.append(f" static inline {qualified_name} default_value() {{")
    lines.append(
        f" return {qualified_name}::{default_ctor}({default_type}{{}});"
    )
    lines.append(" }")
    lines.append("};")
    lines.append("")
    return lines
def generate_namespaced_type(
    self,
    field_type: FieldType,
    nullable: bool = False,
    ref: bool = False,
    element_optional: bool = False,
    element_ref: bool = False,
    weak_ref: bool = False,
    element_weak_ref: bool = False,
    parent_stack: Optional[List[Message]] = None,
) -> str:
    """Generate a C++ type string with named types prefixed by the namespace.

    Args:
        field_type: The IR type to render.
        nullable: Wrap the result in ``std::optional`` (ignored for ``any``).
        ref: Wrap named, list and map types in a shared pointer wrapper.
        element_optional: For lists, render elements as nullable.
        element_ref: For lists, render elements as references.
        weak_ref: Use ``fory::serialization::SharedWeak`` instead of
            ``std::shared_ptr`` when ``ref`` is set; for maps it is also
            forwarded to the value type when the value is a reference.
        element_weak_ref: For lists, render element references as weak.
        parent_stack: Enclosing messages, outermost first, used to resolve
            nested type names.

    Returns:
        The C++ type, or ``"void*"`` for an unrecognized ``field_type``.
    """
    if isinstance(field_type, PrimitiveType):
        base_type = self.PRIMITIVE_MAP[field_type.kind]
        # ``any`` is returned as mapped, never wrapped in std::optional.
        if nullable and field_type.kind != PrimitiveKind.ANY:
            return f"std::optional<{base_type}>"
        return base_type

    if isinstance(field_type, NamedType):
        type_name = self.resolve_nested_type_name(field_type.name, parent_stack)
        namespace = self.get_namespace()
        if namespace:
            type_name = f"{namespace}::{type_name}"
    elif isinstance(field_type, ListType):
        element_type = self.generate_namespaced_type(
            field_type.element_type,
            nullable=element_optional,
            ref=element_ref,
            weak_ref=element_weak_ref,
            parent_stack=parent_stack,
        )
        type_name = f"std::vector<{element_type}>"
    elif isinstance(field_type, MapType):
        # Keyword arguments are required here: passing parent_stack as the
        # sixth positional argument would bind it to ``weak_ref`` and leave
        # the key's nested-name lookup without any enclosing scope.
        key_type = self.generate_namespaced_type(
            field_type.key_type,
            parent_stack=parent_stack,
        )
        value_type = self.generate_namespaced_type(
            field_type.value_type,
            ref=field_type.value_ref,
            weak_ref=weak_ref and field_type.value_ref,
            parent_stack=parent_stack,
        )
        type_name = f"std::map<{key_type}, {value_type}>"
    else:
        return "void*"

    # Reference wrapper first, then optional, so the result reads
    # std::optional<std::shared_ptr<T>>.
    if ref:
        wrapper = (
            "fory::serialization::SharedWeak" if weak_ref else "std::shared_ptr"
        )
        type_name = f"{wrapper}<{type_name}>"
    if nullable:
        type_name = f"std::optional<{type_name}>"
    return type_name
def generate_field_config_macro(
    self,
    message: Message,
    qualified_name: str,
) -> str:
    """Render the ``FORY_FIELD_CONFIG`` macro invocation for ``message``.

    Each field contributes one ``(member_name, meta)`` pair, in declaration
    order, after the qualified type name.
    """
    pairs = ", ".join(
        f"({self.get_field_member_name(member)}, {self.get_field_meta(member)})"
        for member in message.fields
    )
    return f"FORY_FIELD_CONFIG({qualified_name}, {pairs});"
def get_field_meta(self, field: Field) -> str:
    """Build the ``fory::F()`` metadata expression for a message field.

    Modifiers are chained in a fixed order: ``.id(tag)`` when a tag id is
    set, ``.nullable()`` for optional or ``any`` fields, ``.ref()`` for
    reference fields, then the encoding and array-type suffixes when the
    respective helpers return a non-empty string.
    """
    holds_any = (
        isinstance(field.field_type, PrimitiveType)
        and field.field_type.kind == PrimitiveKind.ANY
    )
    parts = ["fory::F()"]
    if field.tag_id is not None:
        parts.append(f".id({field.tag_id})")
    if field.optional or holds_any:
        parts.append(".nullable()")
    if field.ref:
        parts.append(".ref()")
    encoding_suffix = self.get_encoding_config(field.field_type)
    if encoding_suffix:
        parts.append(encoding_suffix)
    array_suffix = self.get_array_type_config(field)
    if array_suffix:
        parts.append(array_suffix)
    return "".join(parts)
def get_union_field_meta(self, field: Field) -> str:
    """Build the ``fory::F(number)`` metadata expression for a union case.

    The case number is passed to ``fory::F``. After it come ``.nullable()``
    for optional or ``any`` cases, ``.ref()`` for reference cases, and the
    encoding and array-type suffixes when the respective helpers return a
    non-empty string.
    """
    holds_any = (
        isinstance(field.field_type, PrimitiveType)
        and field.field_type.kind == PrimitiveKind.ANY
    )
    parts = [f"fory::F({field.number})"]
    if field.optional or holds_any:
        parts.append(".nullable()")
    if field.ref:
        parts.append(".ref()")
    encoding_suffix = self.get_encoding_config(field.field_type)
    if encoding_suffix:
        parts.append(encoding_suffix)
    array_suffix = self.get_array_type_config(field)
    if array_suffix:
        parts.append(array_suffix)
    return "".join(parts)
def get_encoding_config(self, field_type: FieldType) -> str:
    """Return the encoding modifier for a primitive or a list of primitives.

    The primitive kind is taken from the type itself, or from the element
    type when ``field_type`` is a list of primitives. Returns ``".fixed()"``,
    ``".varint()"`` or ``".tagged()"`` for the kinds listed below and an
    empty string for everything else.
    """
    if isinstance(field_type, PrimitiveType):
        kind = field_type.kind
    elif isinstance(field_type, ListType) and isinstance(
        field_type.element_type, PrimitiveType
    ):
        kind = field_type.element_type.kind
    else:
        return ""

    suffix_by_kinds = (
        (
            (
                PrimitiveKind.INT32,
                PrimitiveKind.INT64,
                PrimitiveKind.UINT32,
                PrimitiveKind.UINT64,
            ),
            ".fixed()",
        ),
        ((PrimitiveKind.VAR_UINT32, PrimitiveKind.VAR_UINT64), ".varint()"),
        ((PrimitiveKind.TAGGED_INT64, PrimitiveKind.TAGGED_UINT64), ".tagged()"),
    )
    for kinds, suffix in suffix_by_kinds:
        if kind in kinds:
            return suffix
    return ""
def get_array_type_config(self, field: Field) -> str:
    """Return the array-type modifier for plain int8/uint8 lists.

    Only list fields whose elements are neither optional nor references
    qualify. ``".int8_array()"`` or ``".uint8_array()"`` is returned for the
    matching element kind, and an empty string in every other case.
    """
    declared = field.field_type
    if not isinstance(declared, ListType):
        return ""
    if field.element_optional or field.element_ref:
        return ""
    element = declared.element_type
    if isinstance(element, PrimitiveType):
        if element.kind == PrimitiveKind.INT8:
            return ".int8_array()"
        if element.kind == PrimitiveKind.UINT8:
            return ".uint8_array()"
    return ""
def generate_type(
    self,
    field_type: FieldType,
    nullable: bool = False,
    ref: bool = False,
    element_optional: bool = False,
    element_ref: bool = False,
    weak_ref: bool = False,
    element_weak_ref: bool = False,
    parent_stack: Optional[List[Message]] = None,
) -> str:
    """Generate a C++ type string (without the package namespace prefix).

    Args:
        field_type: The IR type to render.
        nullable: Wrap the result in ``std::optional`` (ignored for ``any``).
        ref: Wrap named, list and map types in a shared pointer wrapper.
        element_optional: For lists, render elements as nullable.
        element_ref: For lists, render elements as references.
        weak_ref: Use ``fory::serialization::SharedWeak`` instead of
            ``std::shared_ptr`` when ``ref`` is set; for maps it is also
            forwarded to the value type when the value is a reference.
        element_weak_ref: For lists, render element references as weak.
        parent_stack: Enclosing messages, outermost first, used to resolve
            nested type names.

    Returns:
        The C++ type, or ``"void*"`` for an unrecognized ``field_type``.
    """
    if isinstance(field_type, PrimitiveType):
        base_type = self.PRIMITIVE_MAP[field_type.kind]
        # ``any`` is returned as mapped, never wrapped in std::optional.
        if nullable and field_type.kind != PrimitiveKind.ANY:
            return f"std::optional<{base_type}>"
        return base_type

    if isinstance(field_type, NamedType):
        type_name = self.resolve_nested_type_name(field_type.name, parent_stack)
    elif isinstance(field_type, ListType):
        element_type = self.generate_type(
            field_type.element_type,
            nullable=element_optional,
            ref=element_ref,
            weak_ref=element_weak_ref,
            parent_stack=parent_stack,
        )
        type_name = f"std::vector<{element_type}>"
    elif isinstance(field_type, MapType):
        # Keyword arguments are required here: passing parent_stack as the
        # sixth positional argument would bind it to ``weak_ref`` and leave
        # the key's nested-name lookup without any enclosing scope.
        key_type = self.generate_type(
            field_type.key_type,
            parent_stack=parent_stack,
        )
        value_type = self.generate_type(
            field_type.value_type,
            ref=field_type.value_ref,
            weak_ref=weak_ref and field_type.value_ref,
            parent_stack=parent_stack,
        )
        type_name = f"std::map<{key_type}, {value_type}>"
    else:
        return "void*"

    # Reference wrapper first, then optional, so the result reads
    # std::optional<std::shared_ptr<T>>.
    if ref:
        wrapper = (
            "fory::serialization::SharedWeak" if weak_ref else "std::shared_ptr"
        )
        type_name = f"{wrapper}<{type_name}>"
    if nullable:
        type_name = f"std::optional<{type_name}>"
    return type_name
def get_qualified_type_name(
    self,
    type_name: str,
    parent_stack: List[Message],
) -> str:
    """Return ``type_name`` prefixed with its enclosing messages, ``::``-joined.

    With an empty ``parent_stack`` the bare ``type_name`` is returned.
    """
    segments = [parent.name for parent in (parent_stack or ())]
    segments.append(type_name)
    return "::".join(segments)
def get_registration_type_name(
    self,
    type_name: str,
    parent_stack: List[Message],
) -> str:
    """Return the dot-qualified name used when registering a type.

    Enclosing message names come first, outermost to innermost, followed by
    ``type_name``. With an empty ``parent_stack`` the bare ``type_name`` is
    returned.
    """
    segments = [parent.name for parent in (parent_stack or ())]
    segments.append(type_name)
    return ".".join(segments)
def resolve_nested_type_name(
    self,
    type_name: str,
    parent_stack: Optional[List[Message]] = None,
) -> str:
    """Resolve a schema type name to a C++ identifier.

    A dotted name is translated by replacing ``.`` with ``::``. An undotted
    name is looked up in the enclosing messages from innermost to outermost;
    the first message that declares it supplies the ``::``-joined prefix.
    When nothing matches, or no parents are given, the name is returned
    unchanged.
    """
    if "." in type_name:
        return type_name.replace(".", "::")
    if not parent_stack:
        return type_name
    # Walk the scopes innermost-first so the closest declaration wins.
    for depth in range(len(parent_stack), 0, -1):
        owner = parent_stack[depth - 1]
        if owner.get_nested_type(type_name) is None:
            continue
        scope = "::".join(parent.name for parent in parent_stack[:depth])
        return f"{scope}::{type_name}"
    return type_name
def collect_includes(
    self,
    field_type: FieldType,
    nullable: bool,
    ref: bool,
    includes: Set[str],
    element_optional: bool = False,
    element_ref: bool = False,
    weak_ref: bool = False,
    element_weak_ref: bool = False,
):
    """Add the headers needed by a field type to ``includes``.

    Args:
        field_type: The IR type being inspected.
        nullable: Whether the type is wrapped in ``std::optional``.
        ref: Whether the type is held through a shared pointer.
        includes: Output set of include targets, updated in place.
        element_optional: For lists, whether elements are nullable.
        element_ref: For lists, whether elements are references.
        weak_ref: Whether the weak pointer serializer header is needed.
        element_weak_ref: For lists, whether element references are weak.
    """
    weak_header = '"fory/serialization/weak_ptr_serializer.h"'
    temporal_header = '"fory/serialization/temporal_serializers.h"'

    # Wrappers requested on the type itself.
    if nullable:
        includes.add("<optional>")
    if ref:
        includes.add("<memory>")
    if weak_ref:
        includes.add(weak_header)

    if isinstance(field_type, PrimitiveType):
        header_by_kind = {
            PrimitiveKind.STRING: "<string>",
            PrimitiveKind.BYTES: "<vector>",
            PrimitiveKind.DATE: temporal_header,
            PrimitiveKind.TIMESTAMP: temporal_header,
            PrimitiveKind.ANY: "<any>",
        }
        header = header_by_kind.get(field_type.kind)
        if header is not None:
            includes.add(header)
        return

    if isinstance(field_type, ListType):
        includes.add("<vector>")
        # The element takes over the element_* flags as its own flags.
        self.collect_includes(
            field_type.element_type,
            element_optional,
            element_ref,
            includes,
            weak_ref=element_weak_ref,
        )
        return

    if isinstance(field_type, MapType):
        includes.add("<map>")
        if field_type.value_ref:
            includes.add("<memory>")
            if weak_ref:
                includes.add(weak_header)
        self.collect_includes(field_type.key_type, False, False, includes)
        self.collect_includes(
            field_type.value_type,
            False,
            field_type.value_ref,
            includes,
            weak_ref=weak_ref and field_type.value_ref,
        )
def generate_registration(self) -> List[str]:
    """Generate the lines of the ``RegisterTypes`` function.

    Top-level enums are registered first, then top-level unions, then
    messages (each message handler also registers its nested types).
    """
    body: List[str] = [
        "inline void RegisterTypes(fory::serialization::Fory& fory) {"
    ]
    for enum_def in self.schema.enums:
        self.generate_enum_registration(body, enum_def, [])
    for union_def in self.schema.unions:
        self.generate_union_registration(body, union_def, [])
    for message_def in self.schema.messages:
        self.generate_message_registration(body, message_def, [])
    body.append("}")
    return body
def generate_enum_registration(
    self, lines: List[str], enum: Enum, parent_stack: List[Message]
):
    """Append the ``register_enum`` statement for ``enum`` to ``lines``.

    An enum with a ``type_id`` is registered by id. Otherwise it is
    registered by namespace and dot-qualified name, where the namespace is
    ``self.package`` or ``"default"`` when no package is set.
    """
    cpp_name = self.get_qualified_type_name(enum.name, parent_stack)
    wire_name = self.get_registration_type_name(enum.name, parent_stack)
    if enum.type_id is None:
        namespace = self.package or "default"
        statement = f' fory.register_enum<{cpp_name}>("{namespace}", "{wire_name}");'
    else:
        statement = f" fory.register_enum<{cpp_name}>({enum.type_id});"
    lines.append(statement)
def generate_message_registration(
    self, lines: List[str], message: Message, parent_stack: List[Message]
):
    """Append registration statements for ``message`` and its nested types.

    Nested enums are registered first, then nested unions, then nested
    messages (recursively), and finally the message itself. A message with
    a ``type_id`` is registered by id; otherwise by namespace
    (``self.package`` or ``"default"``) and dot-qualified name.
    """
    cpp_name = self.get_qualified_type_name(message.name, parent_stack)
    wire_name = self.get_registration_type_name(message.name, parent_stack)

    # Nested declarations see this message as their innermost parent.
    for child_enum in message.nested_enums:
        self.generate_enum_registration(lines, child_enum, [*parent_stack, message])
    for child_union in message.nested_unions:
        self.generate_union_registration(lines, child_union, [*parent_stack, message])
    for child_message in message.nested_messages:
        self.generate_message_registration(
            lines, child_message, [*parent_stack, message]
        )

    if message.type_id is None:
        namespace = self.package or "default"
        statement = f' fory.register_struct<{cpp_name}>("{namespace}", "{wire_name}");'
    else:
        statement = f" fory.register_struct<{cpp_name}>({message.type_id});"
    lines.append(statement)
def generate_union_registration(
    self, lines: List[str], union: Union, parent_stack: List[Message]
):
    """Append the ``register_union`` statement for ``union`` to ``lines``.

    A union with a ``type_id`` is registered by id. Otherwise it is
    registered by namespace and dot-qualified name, where the namespace is
    ``self.package`` or ``"default"`` when no package is set.
    """
    cpp_name = self.get_qualified_type_name(union.name, parent_stack)
    wire_name = self.get_registration_type_name(union.name, parent_stack)
    if union.type_id is None:
        namespace = self.package or "default"
        statement = f' fory.register_union<{cpp_name}>("{namespace}", "{wire_name}");'
    else:
        statement = f" fory.register_union<{cpp_name}>({union.type_id});"
    lines.append(statement)