# blob: 47f2e7cf7822198e921d2f47a5e7ec80380adfca [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Java code generator."""
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Union as TypingUnion
from fory_compiler.generators.base import BaseGenerator, GeneratedFile
from fory_compiler.frontend.utils import parse_idl_file
from fory_compiler.ir.ast import (
Message,
Enum,
Union,
Field,
FieldType,
PrimitiveType,
NamedType,
ListType,
MapType,
Schema,
)
from fory_compiler.ir.types import PrimitiveKind
class JavaGenerator(BaseGenerator):
"""Generates Java POJOs with Fory annotations."""
language_name = "java"
file_extension = ".java"
def get_java_package(self) -> Optional[str]:
    """Resolve the Java package that generated sources belong to.

    Candidates are consulted in order and the first truthy one wins:
      1. the command-line override (``options.package_override``),
      2. the ``java_package`` option declared in the FDL file,
      3. the FDL ``package`` declaration itself.

    Returns:
        The package name, or whatever falsy value the schema's ``package``
        holds when none of the candidates is set.
    """
    override = self.options.package_override
    if override:
        return override
    return self.schema.get_option("java_package") or self.schema.package
def get_registration_class_name(self) -> str:
    """Name the generated registration helper class.

    The name is the PascalCase form of the last dot-separated segment of the
    Java package followed by ``ForyRegistration``; without a package it is
    plain ``ForyRegistration``.
    """
    package = self.get_java_package()
    if not package:
        return "ForyRegistration"
    last_segment = package.rsplit(".", 1)[-1]
    return self.to_pascal_case(last_segment) + "ForyRegistration"
def get_java_outer_classname(self) -> Optional[str]:
    """Return the ``java_outer_classname`` schema option, if any.

    When the option is set, every type is emitted as an inner class of that
    outer class in one file (unless ``java_multiple_files`` is true).
    """
    outer_classname = self.schema.get_option("java_outer_classname")
    return outer_classname
def get_java_multiple_files(self) -> bool:
    """Tell whether the ``java_multiple_files`` option is exactly ``True``.

    When it is, each top-level type is written to its own file even if
    ``java_outer_classname`` is set. Any other option value (including
    truthy non-boolean values) counts as not set.
    """
    return self.schema.get_option("java_multiple_files") is True
def is_imported_type(self, type_def: object) -> bool:
    """Report whether ``type_def`` was declared in a different IDL file.

    A definition is local when the schema has no source file, when the
    definition carries no ``location`` or the location has no file, or when
    both files resolve to the same path.
    """
    own_file = self.schema.source_file
    if not own_file:
        return False
    location = getattr(type_def, "location", None)
    declared_in = None if location is None else location.file
    if not declared_in:
        return False
    try:
        same_file = Path(declared_in).resolve() == Path(own_file).resolve()
    except Exception:
        # Resolution failed; fall back to comparing the raw values.
        same_file = declared_in == own_file
    return not same_file
def split_imported_types(
    self, items: List[object]
) -> Tuple[List[object], List[object]]:
    """Partition ``items`` into ``(imported, local)`` lists.

    Membership is decided by ``is_imported_type``; the relative order of the
    input is kept within each list.
    """
    imported: List[object] = []
    local: List[object] = []
    for item in items:
        bucket = imported if self.is_imported_type(item) else local
        bucket.append(item)
    return imported, local
def _normalize_import_path(self, path_str: str) -> str:
if not path_str:
return path_str
try:
return str(Path(path_str).resolve())
except Exception:
return path_str
def _load_schema(self, file_path: str) -> Optional[Schema]:
if not file_path:
return None
if not hasattr(self, "_schema_cache"):
self._schema_cache = {}
cache: Dict[Path, Schema] = self._schema_cache
path = Path(file_path).resolve()
if path in cache:
return cache[path]
try:
schema = parse_idl_file(path)
except Exception:
return None
cache[path] = schema
return schema
def _java_package_for_schema(self, schema: Schema) -> Optional[str]:
java_package = schema.get_option("java_package")
if java_package:
return java_package
return schema.package
def _registration_class_name_for_schema(self, schema: Schema) -> str:
java_package = self._java_package_for_schema(schema)
if java_package:
parts = java_package.split(".")
return self.to_pascal_case(parts[-1]) + "ForyRegistration"
return "ForyRegistration"
def _java_package_for_type(self, type_def: object) -> Optional[str]:
location = getattr(type_def, "location", None)
file_path = getattr(location, "file", None) if location else None
schema = self._load_schema(file_path)
if schema is None:
return None
return self._java_package_for_schema(schema)
def _collect_imported_packages(self) -> List[Tuple[str, str]]:
packages: Dict[str, str] = {}
for type_def in self.schema.enums + self.schema.unions + self.schema.messages:
if not self.is_imported_type(type_def):
continue
java_package = self._java_package_for_type(type_def)
if not java_package:
continue
if java_package in packages:
continue
schema = self._load_schema(
getattr(getattr(type_def, "location", None), "file", None)
)
if schema is None:
continue
packages[java_package] = self._registration_class_name_for_schema(schema)
ordered: List[Tuple[str, str]] = []
used: Set[str] = set()
if self.schema.source_file:
base_dir = Path(self.schema.source_file).resolve().parent
for imp in self.schema.imports:
candidate = self._normalize_import_path(
str((base_dir / imp.path).resolve())
)
schema = self._load_schema(candidate)
if schema is None:
continue
java_package = self._java_package_for_schema(schema)
if not java_package or java_package in used:
continue
reg_class = self._registration_class_name_for_schema(schema)
ordered.append((java_package, reg_class))
used.add(java_package)
for pkg, reg in sorted(packages.items()):
if pkg in used:
continue
ordered.append((pkg, reg))
return ordered
def generate_bytes_methods(self, class_name: str) -> List[str]:
    """Render the Java ``toBytes`` / ``fromBytes`` helpers for ``class_name``.

    Both helpers delegate to the ``getFory()`` accessor of the registration
    class named by ``get_registration_class_name``. The returned lines are
    unindented; each method is followed by an empty line.
    """
    reg_class = self.get_registration_class_name()
    return [
        "public byte[] toBytes() {",
        f" return {reg_class}.getFory().serialize(this);",
        "}",
        "",
        f"public static {class_name} fromBytes(byte[] bytes) {{",
        f" return {reg_class}.getFory().deserialize(bytes, {class_name}.class);",
        "}",
        "",
    ]
# Mapping from FDL primitive kinds to the Java type used for a non-nullable
# value of that kind.
# - Unsigned kinds share the Java primitive of the same width as their signed
#   counterpart (Java has no unsigned primitive types).
# - All 32-bit / 64-bit encoding variants (fixed, var, tagged) share one type.
# - FLOAT16 is represented as a Java float.
# - Date/time and decimal kinds map to fully-qualified class names.
PRIMITIVE_MAP = {
PrimitiveKind.BOOL: "boolean",
PrimitiveKind.INT8: "byte",
PrimitiveKind.INT16: "short",
PrimitiveKind.INT32: "int",
PrimitiveKind.VARINT32: "int",
PrimitiveKind.INT64: "long",
PrimitiveKind.VARINT64: "long",
PrimitiveKind.TAGGED_INT64: "long",
PrimitiveKind.UINT8: "byte",
PrimitiveKind.UINT16: "short",
PrimitiveKind.UINT32: "int",
PrimitiveKind.VAR_UINT32: "int",
PrimitiveKind.UINT64: "long",
PrimitiveKind.VAR_UINT64: "long",
PrimitiveKind.TAGGED_UINT64: "long",
PrimitiveKind.FLOAT16: "float",
PrimitiveKind.FLOAT32: "float",
PrimitiveKind.FLOAT64: "double",
PrimitiveKind.STRING: "String",
PrimitiveKind.BYTES: "byte[]",
PrimitiveKind.DATE: "java.time.LocalDate",
PrimitiveKind.TIMESTAMP: "java.time.Instant",
PrimitiveKind.DURATION: "java.time.Duration",
PrimitiveKind.DECIMAL: "java.math.BigDecimal",
PrimitiveKind.ANY: "Object",
}
# Boxed (wrapper-class) Java types used in place of the PRIMITIVE_MAP entry
# when a value of that kind must be nullable. Kinds that are absent here
# (STRING, BYTES, DATE, ...) already map to reference types in PRIMITIVE_MAP,
# and generate_type falls back to that map for them.
BOXED_MAP = {
PrimitiveKind.BOOL: "Boolean",
PrimitiveKind.INT8: "Byte",
PrimitiveKind.INT16: "Short",
PrimitiveKind.INT32: "Integer",
PrimitiveKind.VARINT32: "Integer",
PrimitiveKind.INT64: "Long",
PrimitiveKind.VARINT64: "Long",
PrimitiveKind.TAGGED_INT64: "Long",
PrimitiveKind.UINT8: "Byte",
PrimitiveKind.UINT16: "Short",
PrimitiveKind.UINT32: "Integer",
PrimitiveKind.VAR_UINT32: "Integer",
PrimitiveKind.UINT64: "Long",
PrimitiveKind.VAR_UINT64: "Long",
PrimitiveKind.TAGGED_UINT64: "Long",
PrimitiveKind.FLOAT16: "Float",
PrimitiveKind.FLOAT32: "Float",
PrimitiveKind.FLOAT64: "Double",
PrimitiveKind.ANY: "Object",
}
# Java primitive array types used for list fields whose element kind appears
# here and whose elements are neither optional nor reference-tracked (see
# generate_type). Lists of any other element kind are generated as List<...>.
PRIMITIVE_ARRAY_MAP = {
PrimitiveKind.BOOL: "boolean[]",
PrimitiveKind.INT8: "byte[]",
PrimitiveKind.INT16: "short[]",
PrimitiveKind.INT32: "int[]",
PrimitiveKind.VARINT32: "int[]",
PrimitiveKind.INT64: "long[]",
PrimitiveKind.VARINT64: "long[]",
PrimitiveKind.TAGGED_INT64: "long[]",
PrimitiveKind.UINT8: "byte[]",
PrimitiveKind.UINT16: "short[]",
PrimitiveKind.UINT32: "int[]",
PrimitiveKind.VAR_UINT32: "int[]",
PrimitiveKind.UINT64: "long[]",
PrimitiveKind.VAR_UINT64: "long[]",
PrimitiveKind.TAGGED_UINT64: "long[]",
PrimitiveKind.FLOAT16: "float[]",
PrimitiveKind.FLOAT32: "float[]",
PrimitiveKind.FLOAT64: "double[]",
}
def generate(self) -> List[GeneratedFile]:
    """Generate every Java file for the schema.

    Two layouts exist:
      - ``java_outer_classname`` set and ``java_multiple_files`` not true:
        one outer-class file plus a registration helper that is told the
        outer class name.
      - otherwise: one file per locally declared top-level enum, union and
        message (nested types are emitted inside their message's file),
        plus a registration helper.

    Types that come from imported IDL files are never emitted here.
    """
    outer_classname = self.get_java_outer_classname()
    multiple_files = self.get_java_multiple_files()

    if outer_classname and not multiple_files:
        return [
            self.generate_outer_class_file(outer_classname),
            self.generate_registration_file(outer_classname),
        ]

    files = []
    per_kind = (
        (self.schema.enums, self.generate_enum_file),
        (self.schema.unions, self.generate_union_file),
        (self.schema.messages, self.generate_message_file),
    )
    for definitions, render in per_kind:
        files.extend(
            render(definition)
            for definition in definitions
            if not self.is_imported_type(definition)
        )
    files.append(self.generate_registration_file())
    return files
def get_java_package_path(self) -> str:
    """Return the Java package as a ``/``-separated relative directory.

    An empty string is returned when there is no Java package.
    """
    package = self.get_java_package()
    return package.replace(".", "/") if package else ""
def generate_enum_file(self, enum: Enum) -> GeneratedFile:
    """Generate the Java source file for a top-level enum.

    The file holds the license header, an optional package statement and a
    ``public enum`` whose constants are the IDL value names passed through
    ``strip_enum_prefix``.

    Returns:
        A GeneratedFile at ``<package path>/<Name>.java``, or ``<Name>.java``
        when there is no Java package.
    """
    java_package = self.get_java_package()

    lines = [self.get_license_header(), ""]
    if java_package:
        lines += [f"package {java_package};", ""]

    lines.append(f"public enum {enum.name} {{")
    last_index = len(enum.values) - 1
    for index, value in enumerate(enum.values):
        # Constants are comma-separated; the final one closes the list.
        terminator = ";" if index == last_index else ","
        constant = self.strip_enum_prefix(enum.name, value.name)
        lines.append(f" {constant}{terminator}")
    lines += ["}", ""]

    directory = self.get_java_package_path()
    file_name = f"{enum.name}.java"
    path = f"{directory}/{file_name}" if directory else file_name
    return GeneratedFile(path=path, content="\n".join(lines))
def generate_union_file(self, union: Union) -> GeneratedFile:
    """Generate the Java source file for a top-level union.

    The file holds the license header, an optional package statement, the
    sorted imports gathered by ``collect_union_imports`` and the class
    rendered by ``generate_union_class``.

    Returns:
        A GeneratedFile at ``<package path>/<Name>.java``, or ``<Name>.java``
        when there is no Java package.
    """
    imports: Set[str] = set()
    java_package = self.get_java_package()
    self.collect_union_imports(union, imports)

    lines = [self.get_license_header(), ""]
    if java_package:
        lines += [f"package {java_package};", ""]
    if imports:
        lines.extend(f"import {name};" for name in sorted(imports))
        lines.append("")
    lines.extend(self.generate_union_class(union))

    directory = self.get_java_package_path()
    file_name = f"{union.name}.java"
    path = f"{directory}/{file_name}" if directory else file_name
    return GeneratedFile(path=path, content="\n".join(lines))
def generate_message_file(self, message: Message) -> GeneratedFile:
"""Generate the Java source file for a top-level message.

The file consists of the license header, an optional package statement,
the sorted imports gathered by collect_message_imports, and a public
class named after the message. The class body is emitted in this order:
nested enums, nested unions, nested messages, field declarations, a
no-argument constructor, getters/setters, toBytes/fromBytes, equals and
hashCode.

Args:
    message: The message definition to render.

Returns:
    A GeneratedFile whose path is ``<package path>/<Name>.java`` (or
    ``<Name>.java`` when there is no Java package) and whose content is
    the rendered lines joined with newlines.
"""
lines = []
imports: Set[str] = set()
java_package = self.get_java_package()
# Collect imports (including from nested types)
self.collect_message_imports(message, imports)
# License header
lines.append(self.get_license_header())
lines.append("")
# Package
if java_package:
lines.append(f"package {java_package};")
lines.append("")
# Imports (sorted so the output is deterministic)
if imports:
for imp in sorted(imports):
lines.append(f"import {imp};")
lines.append("")
# Class declaration
lines.append(f"public class {message.name} {{")
# Generate nested enums as static inner classes
for nested_enum in message.nested_enums:
for line in self.generate_nested_enum(nested_enum):
lines.append(f" {line}")
# Generate nested unions as static inner classes.
# The message is passed as the parent stack so that union case types can
# be resolved against this message's nested types.
for nested_union in message.nested_unions:
for line in self.generate_union_class(
nested_union, indent=0, nested=True, parent_stack=[message]
):
lines.append(f" {line}")
# Generate nested messages as static inner classes
for nested_msg in message.nested_messages:
for line in self.generate_nested_message(
nested_msg, indent=1, parent_stack=[message]
):
lines.append(f" {line}")
# Fields (annotations plus declaration, as produced by generate_field)
for field in message.fields:
field_lines = self.generate_field(field)
for line in field_lines:
lines.append(f" {line}")
lines.append("")
# Default constructor
lines.append(f" public {message.name}() {{")
lines.append(" }")
lines.append("")
# Getters and setters
for field in message.fields:
getter_setter = self.generate_getter_setter(field)
for line in getter_setter:
lines.append(f" {line}")
# toBytes/fromBytes
for line in self.generate_bytes_methods(message.name):
lines.append(f" {line}")
# equals method
for line in self.generate_equals_method(message):
lines.append(f" {line}")
# hashCode method
for line in self.generate_hashcode_method(message):
lines.append(f" {line}")
lines.append("}")
lines.append("")
# Build file path: the package directory (if any) plus <Name>.java
path = self.get_java_package_path()
if path:
path = f"{path}/{message.name}.java"
else:
path = f"{message.name}.java"
return GeneratedFile(path=path, content="\n".join(lines))
def generate_outer_class_file(self, outer_classname: str) -> GeneratedFile:
    """Generate one Java file holding all local types as inner classes.

    Used when the ``java_outer_classname`` option is set. The file holds the
    license header, an optional package statement, the sorted imports, and a
    ``public final`` outer class with a private constructor. Inside it the
    locally declared top-level enums, unions and messages are emitted (in
    that order) as static nested types.

    Types that come from imported IDL files are skipped, both when emitting
    nested types and when collecting imports, so the import list only
    reflects what this file actually contains.

    Args:
        outer_classname: Name of the outer class and of the generated file.

    Returns:
        A GeneratedFile at ``<package path>/<outer_classname>.java``, or
        ``<outer_classname>.java`` when there is no Java package.
    """
    java_package = self.get_java_package()

    # Imported definitions live in files generated from their own schema.
    local_enums = [e for e in self.schema.enums if not self.is_imported_type(e)]
    local_unions = [u for u in self.schema.unions if not self.is_imported_type(u)]
    local_messages = [
        m for m in self.schema.messages if not self.is_imported_type(m)
    ]

    # Enums contribute no imports, so only messages and unions are walked.
    imports: Set[str] = set()
    for message in local_messages:
        self.collect_message_imports(message, imports)
    for union in local_unions:
        self.collect_union_imports(union, imports)

    lines = [self.get_license_header(), ""]
    if java_package:
        lines += [f"package {java_package};", ""]
    if imports:
        lines.extend(f"import {name};" for name in sorted(imports))
        lines.append("")

    # Outer class declaration with a private constructor.
    lines += [
        f"public final class {outer_classname} {{",
        "",
        f" private {outer_classname}() {{",
        " // Prevent instantiation",
        " }",
        "",
    ]

    for enum in local_enums:
        for line in self.generate_nested_enum(enum):
            lines.append(f" {line}")
    for union in local_unions:
        for line in self.generate_union_class(union, indent=0, nested=True):
            lines.append(f" {line}")
    for message in local_messages:
        for line in self.generate_nested_message(message, indent=1):
            lines.append(f" {line}")

    lines += ["}", ""]

    directory = self.get_java_package_path()
    file_name = f"{outer_classname}.java"
    path = f"{directory}/{file_name}" if directory else file_name
    return GeneratedFile(path=path, content="\n".join(lines))
def collect_message_imports(self, message: Message, imports: Set[str]):
    """Add to ``imports`` everything a message class needs, recursively.

    Covers the message's own fields, the helpers used by the generated
    equals/hashCode (``java.util.Objects`` always, ``java.util.Arrays`` when
    the message or a nested message has an array field), and the imports of
    all nested messages and nested unions.
    """
    for field in message.fields:
        self.collect_field_imports(field, imports)

    # Helpers referenced by equals/hashCode.
    helpers = ["java.util.Objects"]
    if self.has_array_field_recursive(message):
        helpers.append("java.util.Arrays")
    imports.update(helpers)

    for child_message in message.nested_messages:
        self.collect_message_imports(child_message, imports)
    for child_union in message.nested_unions:
        self.collect_union_imports(child_union, imports)
def collect_union_imports(self, union: Union, imports: Set[str]):
    """Add to ``imports`` everything a generated union class needs.

    That is the Fory ``Union`` base class, the ``Types`` constants,
    ``java.util.Objects``, and whatever each case's type requires.
    """
    imports.update(
        (
            "org.apache.fory.type.union.Union",
            "org.apache.fory.type.Types",
            "java.util.Objects",
        )
    )
    for case in union.fields:
        self.collect_type_imports(
            case.field_type,
            imports,
            case.element_optional,
            case.element_ref,
        )
def has_array_field_recursive(self, message: Message) -> bool:
    """Tell whether the message, or any message nested in it, has an array field."""
    if self.has_array_field(message):
        return True
    return any(
        self.has_array_field_recursive(child) for child in message.nested_messages
    )
def generate_nested_enum(self, enum: Enum) -> List[str]:
    """Render an enum as a ``public static enum`` for use inside a class.

    Constant names are the IDL value names passed through
    ``strip_enum_prefix``. The returned lines end with an empty line and
    carry no outer indentation.
    """
    last_index = len(enum.values) - 1
    constants = []
    for index, value in enumerate(enum.values):
        # Constants are comma-separated; the final one closes the list.
        terminator = ";" if index == last_index else ","
        constant = self.strip_enum_prefix(enum.name, value.name)
        constants.append(f" {constant}{terminator}")
    return [f"public static enum {enum.name} {{", *constants, "}", ""]
def generate_union_class(
self,
union: Union,
indent: int = 0,
nested: bool = False,
parent_stack: Optional[List[Message]] = None,
) -> List[str]:
"""Render the Java source lines of a union class extending ``Union``.

The emitted class contains, in order:
  - a ``<Name>Case`` enum with one constant per union field, each carrying
    the field number as its ``id``;
  - a private static ``resolveTypeId`` switch from case id to type id;
  - a private constructor that rejects a null value;
  - one static ``of<Case>`` factory per field;
  - ``get<Name>Case`` and ``get<Name>CaseId`` accessors;
  - ``has<Case>`` / ``get<Case>`` / ``set<Case>`` per field;
  - ``equals`` / ``hashCode`` based on the case index and the value;
  - the ``toBytes`` / ``fromBytes`` helpers from generate_bytes_methods.

Args:
    union: The union definition to render.
    indent: How many times the indentation unit is repeated in front of
        every emitted line.
    nested: When true the class is declared ``public static final`` (for
        use inside another class) instead of ``public final``.
    parent_stack: Enclosing messages, forwarded to
        get_union_case_type_id_expr so case type names can be resolved.

Returns:
    The rendered lines, ending with an empty line.
"""
lines: List[str] = []
ind = " " * indent
class_prefix = "public static final class" if nested else "public final class"
case_enum = f"{union.name}Case"
lines.append(f"{ind}{class_prefix} {union.name} extends Union {{")
# Case enum: one constant per union field, carrying the field number.
lines.append(f"{ind} public enum {case_enum} {{")
for i, field in enumerate(union.fields):
comma = "," if i < len(union.fields) - 1 else ";"
case_name = self.to_upper_snake_case(field.name)
lines.append(f"{ind} {case_name}({field.number}){comma}")
lines.append(f"{ind} public final int id;")
lines.append(f"{ind} {case_enum}(int id) {{")
lines.append(f"{ind} this.id = id;")
lines.append(f"{ind} }}")
lines.append(f"{ind} }}")
lines.append("")
# resolveTypeId: maps a case id to the type-id expression of that case.
lines.append(f"{ind} private static int resolveTypeId(int caseId) {{")
lines.append(f"{ind} switch (caseId) {{")
for field in union.fields:
type_id_expr = self.get_union_case_type_id_expr(field, parent_stack)
lines.append(f"{ind} case {field.number}:")
lines.append(f"{ind} return {type_id_expr};")
lines.append(f"{ind} default:")
lines.append(
f'{ind} throw new IllegalStateException("Unknown {union.name} case id: " + caseId);'
)
lines.append(f"{ind} }}")
lines.append(f"{ind} }}")
lines.append("")
# Private constructor: hands the resolved type id to the base class,
# rejects null, then calls the case getter, which throws for an unknown id.
lines.append(f"{ind} private {union.name}(int caseId, Object v) {{")
lines.append(f"{ind} super(caseId, v, resolveTypeId(caseId));")
lines.append(f"{ind} if (v == null) {{")
lines.append(f"{ind} throw new NullPointerException();")
lines.append(f"{ind} }}")
lines.append(f"{ind} get{union.name}Case();")
lines.append(f"{ind} }}")
lines.append("")
# Static of<Case> factory for every union field.
for field in union.fields:
case_name = self.to_pascal_case(field.name)
case_enum_name = self.to_upper_snake_case(field.name)
case_type = self.get_union_case_type(field)
lines.append(
f"{ind} public static {union.name} of{case_name}({case_type} v) {{"
)
lines.append(
f"{ind} return new {union.name}({case_enum}.{case_enum_name}.id, v);"
)
lines.append(f"{ind} }}")
lines.append("")
# Case accessor: translates the stored index back into the case enum.
lines.append(f"{ind} public {case_enum} get{union.name}Case() {{")
lines.append(f"{ind} switch (index) {{")
for field in union.fields:
case_enum_name = self.to_upper_snake_case(field.name)
lines.append(f"{ind} case {field.number}:")
lines.append(f"{ind} return {case_enum}.{case_enum_name};")
lines.append(f"{ind} default:")
lines.append(
f'{ind} throw new IllegalStateException("Unknown {union.name} case id: " + index);'
)
lines.append(f"{ind} }}")
lines.append(f"{ind} }}")
lines.append("")
lines.append(f"{ind} public int get{union.name}CaseId() {{")
lines.append(f"{ind} return index;")
lines.append(f"{ind} }}")
lines.append("")
# has<Case> / get<Case> / set<Case> for every union field.
for field in union.fields:
case_name = self.to_pascal_case(field.name)
case_enum_name = self.to_upper_snake_case(field.name)
case_type = self.get_union_case_type(field)
# The stored value is an Object, so primitive cases are cast through
# their boxed type (see get_union_case_cast_type).
cast_type = self.get_union_case_cast_type(field)
lines.append(f"{ind} public boolean has{case_name}() {{")
lines.append(
f"{ind} return index == {case_enum}.{case_enum_name}.id;"
)
lines.append(f"{ind} }}")
lines.append("")
lines.append(f"{ind} public {case_type} get{case_name}() {{")
lines.append(
f"{ind} if (index != {case_enum}.{case_enum_name}.id) {{"
)
lines.append(
f'{ind} throw new IllegalStateException("{union.name} is not {case_enum_name}");'
)
lines.append(f"{ind} }}")
lines.append(f"{ind} return ({cast_type}) value;")
lines.append(f"{ind} }}")
lines.append("")
lines.append(f"{ind} public void set{case_name}({case_type} v) {{")
# A null check is only emitted when the parameter is a reference type.
if not self.is_java_primitive_type(case_type):
lines.append(f"{ind} if (v == null) {{")
lines.append(f"{ind} throw new NullPointerException();")
lines.append(f"{ind} }}")
lines.append(f"{ind} this.index = {case_enum}.{case_enum_name}.id;")
lines.append(f"{ind} this.value = v;")
# The setter also updates typeId so it matches the newly selected case.
type_id_expr = self.get_union_case_type_id_expr(field, parent_stack)
lines.append(f"{ind} this.typeId = {type_id_expr};")
lines.append(f"{ind} }}")
lines.append("")
# equals / hashCode consider only the case index and the value.
lines.append(f"{ind} @Override")
lines.append(f"{ind} public boolean equals(Object o) {{")
lines.append(f"{ind} if (this == o) {{")
lines.append(f"{ind} return true;")
lines.append(f"{ind} }}")
lines.append(f"{ind} if (!(o instanceof {union.name})) {{")
lines.append(f"{ind} return false;")
lines.append(f"{ind} }}")
lines.append(f"{ind} {union.name} that = ({union.name}) o;")
lines.append(
f"{ind} return index == that.index && Objects.equals(value, that.value);"
)
lines.append(f"{ind} }}")
lines.append("")
lines.append(f"{ind} @Override")
lines.append(f"{ind} public int hashCode() {{")
lines.append(f"{ind} return Objects.hash(index, value);")
lines.append(f"{ind} }}")
lines.append("")
# toBytes / fromBytes helpers, indented into the class body.
for line in self.generate_bytes_methods(union.name):
lines.append(f"{ind} {line}")
lines.append(f"{ind}}}")
lines.append("")
return lines
def get_union_case_type(self, field: Field) -> str:
    """Return the Java type of a union case, rendered as non-nullable.

    The field's element modifiers (``element_optional`` / ``element_ref``)
    are passed through to ``generate_type``.
    """
    nullable = False
    case_type = self.generate_type(
        field.field_type,
        nullable,
        field.element_optional,
        field.element_ref,
    )
    return case_type
def get_union_case_cast_type(self, field: Field) -> str:
    """Return the Java type used to cast a union's stored value for a case.

    Primitive cases use their boxed type when one exists in ``BOXED_MAP``
    and their ``PRIMITIVE_MAP`` type otherwise; all other cases use the
    regular case type.
    """
    case_type = field.field_type
    if not isinstance(case_type, PrimitiveType):
        return self.get_union_case_type(field)
    boxed = self.BOXED_MAP.get(case_type.kind)
    if boxed is None:
        return self.PRIMITIVE_MAP[case_type.kind]
    return boxed
def get_union_case_type_id_expr(
    self, field: Field, parent_stack: Optional[List[Message]]
) -> str:
    """Return the Java expression for the type id of a union case value.

    Primitive, list and map cases map to a ``Types`` constant. Named cases
    are resolved through ``resolve_named_type``: a definition without a
    ``type_id`` yields the matching ``Types.NAMED_*`` constant, one with a
    ``type_id`` yields ``(<type_id> << 8) | Types.<KIND>``. Anything that
    cannot be classified yields ``Types.UNKNOWN``.
    """
    case_type = field.field_type

    if isinstance(case_type, PrimitiveType):
        primitive_type_ids = {
            PrimitiveKind.BOOL: "Types.BOOL",
            PrimitiveKind.INT8: "Types.INT8",
            PrimitiveKind.INT16: "Types.INT16",
            PrimitiveKind.INT32: "Types.INT32",
            PrimitiveKind.VARINT32: "Types.VARINT32",
            PrimitiveKind.INT64: "Types.INT64",
            PrimitiveKind.VARINT64: "Types.VARINT64",
            PrimitiveKind.TAGGED_INT64: "Types.TAGGED_INT64",
            PrimitiveKind.UINT8: "Types.UINT8",
            PrimitiveKind.UINT16: "Types.UINT16",
            PrimitiveKind.UINT32: "Types.UINT32",
            PrimitiveKind.VAR_UINT32: "Types.VAR_UINT32",
            PrimitiveKind.UINT64: "Types.UINT64",
            PrimitiveKind.VAR_UINT64: "Types.VAR_UINT64",
            PrimitiveKind.TAGGED_UINT64: "Types.TAGGED_UINT64",
            PrimitiveKind.FLOAT16: "Types.FLOAT16",
            PrimitiveKind.FLOAT32: "Types.FLOAT32",
            PrimitiveKind.FLOAT64: "Types.FLOAT64",
            PrimitiveKind.STRING: "Types.STRING",
            PrimitiveKind.BYTES: "Types.BINARY",
            PrimitiveKind.DATE: "Types.DATE",
            PrimitiveKind.TIMESTAMP: "Types.TIMESTAMP",
            PrimitiveKind.ANY: "Types.UNKNOWN",
        }
        return primitive_type_ids.get(case_type.kind, "Types.UNKNOWN")

    if isinstance(case_type, ListType):
        return "Types.LIST"
    if isinstance(case_type, MapType):
        return "Types.MAP"

    if isinstance(case_type, NamedType):
        type_def = self.resolve_named_type(case_type.name, parent_stack)
        # (definition class, id when unregistered by number, base kind id)
        named_kinds = (
            (Enum, "Types.NAMED_ENUM", "Types.ENUM"),
            (Union, "Types.NAMED_UNION", "Types.UNION"),
            (Message, "Types.NAMED_STRUCT", "Types.STRUCT"),
        )
        for definition_cls, named_id, base_id in named_kinds:
            if isinstance(type_def, definition_cls):
                if type_def.type_id is None:
                    return named_id
                return f"({type_def.type_id} << 8) | {base_id}"

    return "Types.UNKNOWN"
def resolve_named_type(
    self, name: str, parent_stack: Optional[List[Message]]
) -> Optional[TypingUnion[Message, Enum, Union]]:
    """Resolve a type name to its schema definition.

    A dotted name is looked up from a top-level type and then walked through
    nested types; the walk yields ``None`` as soon as a non-message is asked
    for a nested type. A simple name is searched in the enclosing messages
    from innermost to outermost, then among the top-level types.
    """
    head, *tail = name.split(".")
    if tail:
        current = self.find_top_level_type(head)
        for part in tail:
            if not isinstance(current, Message):
                return None
            current = current.get_nested_type(part)
        return current

    for enclosing in reversed(parent_stack or []):
        nested = enclosing.get_nested_type(name)
        if nested is not None:
            return nested
    return self.find_top_level_type(name)
def find_top_level_type(
    self, name: str
) -> Optional[TypingUnion[Message, Enum, Union]]:
    """Find a top-level definition called ``name``.

    Messages are searched first, then enums, then unions; the first match
    wins and ``None`` is returned when nothing matches.
    """
    groups = (self.schema.messages, self.schema.enums, self.schema.unions)
    for group in groups:
        for definition in group:
            if definition.name == name:
                return definition
    return None
def is_java_primitive_type(self, type_name: str) -> bool:
    """Tell whether ``type_name`` names one of the Java primitive types listed below."""
    java_primitives = frozenset(
        ("boolean", "byte", "short", "int", "long", "float", "double", "char")
    )
    return type_name in java_primitives
def generate_nested_message(
self,
message: Message,
indent: int = 1,
parent_stack: Optional[List[Message]] = None,
) -> List[str]:
"""Render a message as a ``public static class`` for use inside a class.

The class body is emitted in this order: nested enums, nested unions,
nested messages (recursively), field declarations, a no-argument
constructor, getters/setters, toBytes/fromBytes, equals and hashCode.

Args:
    message: The message definition to render.
    indent: Not read by this method's body; callers indent the returned
        lines themselves.
    parent_stack: Messages enclosing ``message``, outermost first.

Returns:
    The rendered lines, ending with an empty line.
"""
lines = []
# Enclosing messages plus this one; handed to nested unions and nested
# messages so that type names can be resolved against every level.
lineage = (parent_stack or []) + [message]
# Class declaration
lines.append(f"public static class {message.name} {{")
# Generate nested enums
for nested_enum in message.nested_enums:
for line in self.generate_nested_enum(nested_enum):
lines.append(f" {line}")
# Generate nested unions
for nested_union in message.nested_unions:
for line in self.generate_union_class(
nested_union,
indent=0,
nested=True,
parent_stack=lineage,
):
lines.append(f" {line}")
# Generate nested messages (recursively)
for nested_msg in message.nested_messages:
for line in self.generate_nested_message(
nested_msg,
indent=1,
parent_stack=lineage,
):
lines.append(f" {line}")
# Fields (annotations plus declaration, as produced by generate_field)
for field in message.fields:
field_lines = self.generate_field(field)
for line in field_lines:
lines.append(f" {line}")
lines.append("")
# Default constructor
lines.append(f" public {message.name}() {{")
lines.append(" }")
lines.append("")
# Getters and setters
for field in message.fields:
getter_setter = self.generate_getter_setter(field)
for line in getter_setter:
lines.append(f" {line}")
# toBytes/fromBytes
for line in self.generate_bytes_methods(message.name):
lines.append(f" {line}")
# equals method
for line in self.generate_equals_method(message):
lines.append(f" {line}")
# hashCode method
for line in self.generate_hashcode_method(message):
lines.append(f" {line}")
lines.append("}")
lines.append("")
return lines
def generate_field(self, field: Field) -> List[str]:
    """Render a field declaration together with its annotations.

    A ``@ForyField`` annotation is emitted only when at least one of its
    attributes applies (tag id, nullability, reference tracking). Fields of
    kind ``ANY`` are always treated as nullable. Array and integer-encoding
    annotations follow when their helpers return one. The returned lines
    end with an empty line.
    """
    field_type = field.field_type
    is_any = (
        isinstance(field_type, PrimitiveType)
        and field_type.kind == PrimitiveKind.ANY
    )
    nullable = field.optional or is_any

    # Attributes of @ForyField, in the order they are written out.
    attributes = []
    if field.tag_id is not None:
        attributes.append(f"id = {field.tag_id}")
    if nullable:
        attributes.append("nullable = true")
    if field.ref:
        attributes.append("ref = true")

    lines = []
    if attributes:
        lines.append(f"@ForyField({', '.join(attributes)})")

    array_annotation = self.get_array_annotation(field)
    if array_annotation:
        lines.append(array_annotation)

    int_annotation = self.get_integer_annotation(field_type)
    if int_annotation:
        lines.append(int_annotation)

    java_type = self.generate_type(
        field_type,
        nullable,
        field.element_optional,
        field.element_ref,
    )
    declaration = f"private {java_type} {self.to_camel_case(field.name)};"
    lines.extend((declaration, ""))
    return lines
def generate_getter_setter(self, field: Field) -> List[str]:
    """Render the Java getter and setter of a field.

    The accessor type is the same one used for the field declaration:
    nullable when the field is optional or of kind ``ANY``. Each method is
    followed by an empty line.
    """
    field_type = field.field_type
    is_any = (
        isinstance(field_type, PrimitiveType)
        and field_type.kind == PrimitiveKind.ANY
    )
    nullable = field.optional or is_any
    java_type = self.generate_type(
        field_type,
        nullable,
        field.element_optional,
        field.element_ref,
    )
    field_name = self.to_camel_case(field.name)
    pascal_name = self.to_pascal_case(field.name)

    getter = [
        f"public {java_type} get{pascal_name}() {{",
        f" return {field_name};",
        "}",
        "",
    ]
    setter = [
        f"public void set{pascal_name}({java_type} {field_name}) {{",
        f" this.{field_name} = {field_name};",
        "}",
        "",
    ]
    return getter + setter
def generate_type(
    self,
    field_type: FieldType,
    nullable: bool = False,
    element_optional: bool = False,
    element_ref: bool = False,
) -> str:
    """Return the Java type string for ``field_type``.

    Args:
        field_type: The IDL type to translate.
        nullable: Use the boxed form of a primitive when one exists.
        element_optional: List elements may be null; disables primitive arrays.
        element_ref: List elements are reference-tracked; disables primitive
            arrays and selects ``@Ref`` over ``@Ref(enable=false)``.

    Returns:
        The Java type; ``Object`` for a type of an unrecognised class.
    """
    if isinstance(field_type, PrimitiveType):
        kind = field_type.kind
        if nullable and kind in self.BOXED_MAP:
            return self.BOXED_MAP[kind]
        return self.PRIMITIVE_MAP[kind]

    if isinstance(field_type, NamedType):
        definition = self.schema.get_type(field_type.name)
        if definition is not None and self.is_imported_type(definition):
            # Imported types are referenced by their fully-qualified name.
            owner_package = self._java_package_for_type(definition)
            if owner_package:
                return f"{owner_package}.{field_type.name}"
        return field_type.name

    if isinstance(field_type, ListType):
        element = field_type.element_type
        use_primitive_array = (
            isinstance(element, PrimitiveType)
            and element.kind in self.PRIMITIVE_ARRAY_MAP
            and not element_optional
            and not element_ref
        )
        if use_primitive_array:
            return self.PRIMITIVE_ARRAY_MAP[element.kind]
        rendered = self.generate_type(element, True)
        if self.is_ref_target_type(element):
            marker = "@Ref" if element_ref else "@Ref(enable=false)"
            rendered = f"{marker} {rendered}"
        return f"List<{rendered}>"

    if isinstance(field_type, MapType):
        key_type = self.generate_type(field_type.key_type, True)
        value_type = self.generate_type(field_type.value_type, True)
        if self.is_ref_target_type(field_type.value_type):
            marker = "@Ref" if field_type.value_ref else "@Ref(enable=false)"
            value_type = f"{marker} {value_type}"
        return f"Map<{key_type}, {value_type}>"

    return "Object"
def collect_type_imports(
    self,
    field_type: FieldType,
    imports: Set[str],
    element_optional: bool = False,
    element_ref: bool = False,
):
    """Add to ``imports`` the Java imports that ``field_type`` requires.

    Date and timestamp primitives need their ``java.time`` class. Lists
    rendered as primitive arrays need nothing; other lists need
    ``java.util.List`` and maps need ``java.util.Map``, plus the ``Ref``
    annotation when the element or value type is a reference target, plus
    the imports of their element, key and value types.
    """
    if isinstance(field_type, PrimitiveType):
        temporal_imports = {
            PrimitiveKind.DATE: "java.time.LocalDate",
            PrimitiveKind.TIMESTAMP: "java.time.Instant",
        }
        needed = temporal_imports.get(field_type.kind)
        if needed is not None:
            imports.add(needed)
        return

    if isinstance(field_type, ListType):
        element = field_type.element_type
        rendered_as_array = (
            isinstance(element, PrimitiveType)
            and element.kind in self.PRIMITIVE_ARRAY_MAP
            and not element_optional
            and not element_ref
        )
        if rendered_as_array:
            return  # Primitive arrays need no import.
        imports.add("java.util.List")
        if self.is_ref_target_type(element):
            imports.add("org.apache.fory.annotation.Ref")
        self.collect_type_imports(element, imports)
        return

    if isinstance(field_type, MapType):
        imports.add("java.util.Map")
        if self.is_ref_target_type(field_type.value_type):
            imports.add("org.apache.fory.annotation.Ref")
        self.collect_type_imports(field_type.key_type, imports)
        self.collect_type_imports(field_type.value_type, imports)
def collect_field_imports(self, field: Field, imports: Set[str]):
    """Gather every import a single field needs, list modifiers included.

    Delegates to the type, integer-annotation and array-annotation
    collectors, then adds ``ForyField`` when the field is optional, a
    reference, carries a tag id, or is of the ``any`` primitive kind.

    Args:
        field: The field being generated.
        imports: Set of fully qualified Java names, updated in place.
    """
    declared_type = field.field_type
    holds_any = (
        isinstance(declared_type, PrimitiveType)
        and declared_type.kind == PrimitiveKind.ANY
    )

    self.collect_type_imports(
        declared_type,
        imports,
        field.element_optional,
        field.element_ref,
    )
    self.collect_integer_imports(declared_type, imports)
    self.collect_array_imports(field, imports)

    needs_fory_field = (
        field.optional or field.ref or field.tag_id is not None or holds_any
    )
    if needs_fory_field:
        imports.add("org.apache.fory.annotation.ForyField")
def is_ref_target_type(self, field_type: FieldType) -> bool:
    """Return True when ``field_type`` names a message or union in the schema.

    Anything that is not a ``NamedType`` is never a ref target; named types
    are resolved through ``self.schema.get_type``.
    """
    return isinstance(field_type, NamedType) and isinstance(
        self.schema.get_type(field_type.name), (Message, Union)
    )
def collect_array_imports(self, field: Field, imports: Set[str]) -> None:
    """Add the import for a primitive-list field's array type annotation.

    Nothing is added unless the field is a list of primitives whose
    elements are neither optional nor references, and the element kind has
    an entry in the table below.

    Args:
        field: The field being generated.
        imports: Set of fully qualified Java names, updated in place.
    """
    declared_type = field.field_type
    if (
        not isinstance(declared_type, ListType)
        or field.element_optional
        or field.element_ref
    ):
        return
    item_type = declared_type.element_type
    if not isinstance(item_type, PrimitiveType):
        return

    uint32_array = "org.apache.fory.annotation.Uint32ArrayType"
    uint64_array = "org.apache.fory.annotation.Uint64ArrayType"
    import_by_kind = {
        PrimitiveKind.INT8: "org.apache.fory.annotation.Int8ArrayType",
        PrimitiveKind.UINT8: "org.apache.fory.annotation.Uint8ArrayType",
        PrimitiveKind.UINT16: "org.apache.fory.annotation.Uint16ArrayType",
        PrimitiveKind.UINT32: uint32_array,
        PrimitiveKind.VAR_UINT32: uint32_array,
        PrimitiveKind.UINT64: uint64_array,
        PrimitiveKind.VAR_UINT64: uint64_array,
        PrimitiveKind.TAGGED_UINT64: uint64_array,
    }
    needed = import_by_kind.get(item_type.kind)
    if needed is not None:
        imports.add(needed)
def collect_integer_imports(self, field_type: FieldType, imports: Set[str]) -> None:
    """Add the imports used by integer encoding annotations.

    Non-primitive types and primitive kinds without an entry in the table
    below leave ``imports`` unchanged.

    Args:
        field_type: The field's declared type.
        imports: Set of fully qualified Java names, updated in place.
    """
    if not isinstance(field_type, PrimitiveType):
        return

    # The 64-bit annotations take a LongEncoding argument, so they need the
    # enum import as well.
    int64_imports = (
        "org.apache.fory.annotation.Int64Type",
        "org.apache.fory.config.LongEncoding",
    )
    uint32_imports = ("org.apache.fory.annotation.Uint32Type",)
    uint64_imports = (
        "org.apache.fory.annotation.Uint64Type",
        "org.apache.fory.config.LongEncoding",
    )
    imports_by_kind = {
        PrimitiveKind.INT32: ("org.apache.fory.annotation.Int32Type",),
        PrimitiveKind.INT64: int64_imports,
        PrimitiveKind.TAGGED_INT64: int64_imports,
        PrimitiveKind.UINT8: ("org.apache.fory.annotation.Uint8Type",),
        PrimitiveKind.UINT16: ("org.apache.fory.annotation.Uint16Type",),
        PrimitiveKind.UINT32: uint32_imports,
        PrimitiveKind.VAR_UINT32: uint32_imports,
        PrimitiveKind.UINT64: uint64_imports,
        PrimitiveKind.VAR_UINT64: uint64_imports,
        PrimitiveKind.TAGGED_UINT64: uint64_imports,
    }
    imports.update(imports_by_kind.get(field_type.kind, ()))
def get_integer_annotation(self, field_type: FieldType) -> Optional[str]:
    """Look up the integer encoding annotation for a field type.

    Returns:
        The Java annotation text, or ``None`` when ``field_type`` is not a
        primitive or its kind has no entry in the table below.
    """
    if not isinstance(field_type, PrimitiveType):
        return None

    annotation_by_kind = {
        PrimitiveKind.INT32: "@Int32Type(compress = false)",
        PrimitiveKind.INT64: "@Int64Type(encoding = LongEncoding.FIXED)",
        PrimitiveKind.TAGGED_INT64: "@Int64Type(encoding = LongEncoding.TAGGED)",
        PrimitiveKind.UINT8: "@Uint8Type",
        PrimitiveKind.UINT16: "@Uint16Type",
        PrimitiveKind.UINT32: "@Uint32Type(compress = false)",
        PrimitiveKind.VAR_UINT32: "@Uint32Type(compress = true)",
        PrimitiveKind.UINT64: "@Uint64Type(encoding = LongEncoding.FIXED)",
        PrimitiveKind.VAR_UINT64: "@Uint64Type(encoding = LongEncoding.VARINT)",
        PrimitiveKind.TAGGED_UINT64: "@Uint64Type(encoding = LongEncoding.TAGGED)",
    }
    return annotation_by_kind.get(field_type.kind)
def get_array_annotation(self, field: Field) -> Optional[str]:
    """Look up the array type annotation for a primitive-list field.

    Returns:
        The Java annotation text, or ``None`` when the field is not a list
        of primitives, its elements are optional or references, or the
        element kind has no entry in the table below.
    """
    declared_type = field.field_type
    if (
        not isinstance(declared_type, ListType)
        or field.element_optional
        or field.element_ref
    ):
        return None
    item_type = declared_type.element_type
    if not isinstance(item_type, PrimitiveType):
        return None

    annotation_by_kind = {
        PrimitiveKind.INT8: "@Int8ArrayType",
        PrimitiveKind.UINT8: "@Uint8ArrayType",
        PrimitiveKind.UINT16: "@Uint16ArrayType",
        PrimitiveKind.UINT32: "@Uint32ArrayType",
        PrimitiveKind.VAR_UINT32: "@Uint32ArrayType",
        PrimitiveKind.UINT64: "@Uint64ArrayType",
        PrimitiveKind.VAR_UINT64: "@Uint64ArrayType",
        PrimitiveKind.TAGGED_UINT64: "@Uint64ArrayType",
    }
    return annotation_by_kind.get(item_type.kind)
def has_array_field(self, message: Message) -> bool:
    """Tell whether any field of ``message`` is ``bytes`` or a primitive array.

    ``is_primitive_array_field`` already reports ``bytes`` primitives as
    arrays and every other primitive as a non-array, so it can answer for
    each field directly.
    """
    return any(self.is_primitive_array_field(member) for member in message.fields)
def is_primitive_array_field(self, field: Field) -> bool:
    """Tell whether ``field`` is treated as a primitive array.

    True for a ``bytes`` primitive, and for a list whose element is a
    primitive kind present in ``PRIMITIVE_ARRAY_MAP`` with neither the
    element-optional nor the element-ref modifier set.
    """
    declared_type = field.field_type
    if isinstance(declared_type, PrimitiveType):
        return declared_type.kind == PrimitiveKind.BYTES
    if not isinstance(declared_type, ListType):
        return False

    item_type = declared_type.element_type
    if not isinstance(item_type, PrimitiveType):
        return False
    return item_type.kind in self.PRIMITIVE_ARRAY_MAP and not (
        field.element_optional or field.element_ref
    )
def generate_equals_method(self, message: Message) -> List[str]:
"""Generate equals() method for a message."""
lines = []
lines.append("@Override")
lines.append("public boolean equals(Object o) {")
lines.append(" if (this == o) return true;")
lines.append(" if (o == null || getClass() != o.getClass()) return false;")
lines.append(f" {message.name} that = ({message.name}) o;")
if not message.fields:
lines.append(" return true;")
else:
comparisons = []
for field in message.fields:
field_name = self.to_camel_case(field.name)
if self.is_primitive_array_field(field):
comparisons.append(
f"Arrays.equals({field_name}, that.{field_name})"
)
elif isinstance(field.field_type, PrimitiveType):
kind = field.field_type.kind
if kind in (PrimitiveKind.FLOAT32,):
comparisons.append(
f"Float.compare({field_name}, that.{field_name}) == 0"
)
elif kind in (PrimitiveKind.FLOAT64,):
comparisons.append(
f"Double.compare({field_name}, that.{field_name}) == 0"
)
elif (
kind
in (
PrimitiveKind.BOOL,
PrimitiveKind.INT8,
PrimitiveKind.INT16,
PrimitiveKind.INT32,
PrimitiveKind.INT64,
)
and not field.optional
):
comparisons.append(f"{field_name} == that.{field_name}")
else:
comparisons.append(
f"Objects.equals({field_name}, that.{field_name})"
)
else:
comparisons.append(
f"Objects.equals({field_name}, that.{field_name})"
)
if len(comparisons) == 1:
lines.append(f" return {comparisons[0]};")
else:
lines.append(f" return {comparisons[0]}")
for i, comp in enumerate(comparisons[1:], 1):
if i == len(comparisons) - 1:
lines.append(f" && {comp};")
else:
lines.append(f" && {comp}")
lines.append("}")
lines.append("")
return lines
def generate_hashcode_method(self, message: Message) -> List[str]:
    """Generate the Java ``hashCode()`` override for a message class.

    Primitive-array fields (per ``is_primitive_array_field``) are hashed
    with ``Arrays.hashCode``; all other fields go through a single
    ``Objects.hash`` call. When both groups are present, or there are
    several arrays, the parts are combined with the ``31 * result + ...``
    accumulation.

    Args:
        message: The message whose fields are hashed.

    Returns:
        The source lines of the method followed by one empty separator
        line. A message without fields yields a body that returns ``0``.
    """
    body = ["@Override", "public int hashCode() {"]

    if message.fields:
        plain_names = []
        array_names = []
        for member in message.fields:
            java_name = self.to_camel_case(member.name)
            if self.is_primitive_array_field(member):
                array_names.append(java_name)
            else:
                plain_names.append(java_name)

        if not array_names:
            body.append(f" return Objects.hash({', '.join(plain_names)});")
        elif not plain_names and len(array_names) == 1:
            body.append(f" return Arrays.hashCode({array_names[0]});")
        else:
            # Seed the accumulator with the non-array hash when there is
            # one, otherwise with the first array; fold in the remaining
            # arrays afterwards.
            if plain_names:
                seed = f"Objects.hash({', '.join(plain_names)})"
                remaining = array_names
            else:
                seed = f"Arrays.hashCode({array_names[0]})"
                remaining = array_names[1:]
            body.append(f" int result = {seed};")
            body.extend(
                f" result = 31 * result + Arrays.hashCode({name});"
                for name in remaining
            )
            body.append(" return result;")
    else:
        body.append(" return 0;")

    body += ["}", ""]
    return body
def generate_registration_file(
    self, outer_classname: Optional[str] = None
) -> GeneratedFile:
    """Build the Java helper class that registers this schema's types.

    The generated class exposes a lazily created ``ThreadSafeFory`` through
    a holder class and a static ``register(Fory)`` method covering every
    enum, union and message that is not an imported type.

    Args:
        outer_classname: If set, all type references will be prefixed with
            this outer class.

    Returns:
        A ``GeneratedFile`` whose path is ``<package path>/<Class>.java``,
        or just ``<Class>.java`` when there is no package path.
    """
    package_name = self.get_java_package()
    class_name = self.get_registration_class_name()

    # License header, optional package declaration, then imports.
    out: List[str] = [self.get_license_header(), ""]
    if package_name:
        out += [f"package {package_name};", ""]
    out += [
        "import org.apache.fory.Fory;",
        "import org.apache.fory.ThreadSafeFory;",
        "import org.apache.fory.pool.SimpleForyPool;",
        "",
    ]

    # Class header, accessor and the start of the factory method.
    out += [
        f"public class {class_name} {{",
        "",
        " static ThreadSafeFory getFory() {",
        " return Holder.FORY;",
        " }",
        "",
        " private static ThreadSafeFory createFory() {",
        " ThreadSafeFory fory = new SimpleForyPool(c -> Fory.builder().withXlang(true).withRefTracking(true).build());",
    ]

    # Registration helpers of imported packages are invoked before our own.
    dependencies = self._collect_imported_packages()
    if dependencies:
        out.append(" fory.registerCallback(f -> {")
        for dep_package, dep_class in dependencies:
            out.append(f" {dep_package}.{dep_class}.register(f);")
        out += [" register(f);", " });"]
    else:
        out.append(" fory.registerCallback(f -> register(f));")
    out += [
        " return fory;",
        " }",
        "",
        " private static class Holder {",
        " private static final ThreadSafeFory FORY = createFory();",
        " }",
        "",
    ]

    # When outer_classname is set, all top-level types become inner classes.
    prefix = outer_classname or ""
    own_enums = [e for e in self.schema.enums if not self.is_imported_type(e)]
    own_unions = [u for u in self.schema.unions if not self.is_imported_type(u)]
    own_messages = [
        m for m in self.schema.messages if not self.is_imported_type(m)
    ]

    out.append(" public static void register(Fory fory) {")
    for enum_def in own_enums:
        self.generate_enum_registration(out, enum_def, prefix)
    for union_def in own_unions:
        self.generate_union_registration(out, union_def, prefix)
    # Message registration also covers each message's nested types.
    for message_def in own_messages:
        self.generate_message_registration(out, message_def, prefix)
    out += [" }", "}", ""]

    package_dir = self.get_java_package_path()
    file_name = f"{class_name}.java"
    target = f"{package_dir}/{file_name}" if package_dir else file_name
    return GeneratedFile(path=target, content="\n".join(out))
def generate_enum_registration(
    self, lines: List[str], enum: Enum, parent_path: str
):
    """Append the ``fory.register`` statement for an enum to ``lines``.

    Args:
        lines: Output buffer, appended to in place.
        enum: The enum to register.
        parent_path: Qualified name of the enclosing Java class, or an
            empty string for a top-level enum.
    """
    # In Java, nested class references use OuterClass.InnerClass.
    qualified = f"{parent_path}.{enum.name}" if parent_path else enum.name

    if enum.type_id is None:
        # Name-based registration: use the FDL package as namespace
        # (consistent across languages); the registered type name is the
        # same qualified reference used for the class literal.
        namespace = self.schema.package or "default"
        statement = (
            f' fory.register({qualified}.class, "{namespace}", "{qualified}");'
        )
    else:
        statement = f" fory.register({qualified}.class, {enum.type_id});"
    lines.append(statement)
def generate_message_registration(
    self, lines: List[str], message: Message, parent_path: str
):
    """Append registration statements for a message and its nested types.

    The message itself is registered first, followed by its nested enums,
    nested unions and nested messages (recursively), each using this
    message's qualified name as their parent path.

    Args:
        lines: Output buffer, appended to in place.
        message: The message to register.
        parent_path: Qualified name of the enclosing Java class, or an
            empty string for a top-level message.
    """
    # In Java, nested class references use OuterClass.InnerClass.
    if parent_path:
        qualified = f"{parent_path}.{message.name}"
    else:
        qualified = message.name

    if message.type_id is None:
        # Name-based registration: use the FDL package as namespace
        # (consistent across languages); the registered type name is the
        # same qualified reference used for the class literal.
        namespace = self.schema.package or "default"
        lines.append(
            f' fory.register({qualified}.class, "{namespace}", "{qualified}");'
        )
    else:
        lines.append(f" fory.register({qualified}.class, {message.type_id});")

    for child_enum in message.nested_enums:
        self.generate_enum_registration(lines, child_enum, qualified)
    for child_union in message.nested_unions:
        self.generate_union_registration(lines, child_union, qualified)
    for child_message in message.nested_messages:
        self.generate_message_registration(lines, child_message, qualified)
def generate_union_registration(
    self, lines: List[str], union: Union, parent_path: str
):
    """Append the ``fory.registerUnion`` statement for a union to ``lines``.

    With a type id the union is registered by id. Otherwise it is
    registered by name: the namespace is the schema package (or
    ``default``) extended with ``parent_path`` when present, and the type
    name is the union's simple name.

    Args:
        lines: Output buffer, appended to in place.
        union: The union to register.
        parent_path: Qualified name of the enclosing Java class, or an
            empty string for a top-level union.
    """
    if parent_path:
        qualified = f"{parent_path}.{union.name}"
    else:
        qualified = union.name
    serializer = (
        f"new org.apache.fory.serializer.UnionSerializer(fory, {qualified}.class)"
    )

    if union.type_id is not None:
        lines.append(
            f" fory.registerUnion({qualified}.class, {union.type_id}, {serializer});"
        )
        return

    package = self.schema.package or "default"
    namespace = f"{package}.{parent_path}" if parent_path else package
    lines.append(
        f' fory.registerUnion({qualified}.class, "{namespace}", "{union.name}", {serializer});'
    )