| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """Code generation logic for the `tvm-ffi-stubgen` tool.""" |
| |
| from __future__ import annotations |
| |
| from typing import Callable |
| |
| from . import consts as C |
| from .file_utils import CodeBlock |
| from .utils import FuncInfo, ImportItem, InitConfig, ObjectInfo, Options |
| |
| |
| def _type_suffix_and_record( |
| ty_map: dict[str, str], |
| imports: list[ImportItem], |
| func_names: set[str] | None = None, |
| ) -> Callable[[str], str]: |
| def _run(name: str) -> str: |
| nonlocal ty_map, imports |
| name = ty_map.get(name, name) |
| suffix = name.rsplit(".", 1)[-1] |
| if "." in name: |
| alias = None |
| if func_names and suffix in func_names: |
| alias = f"_{suffix}" |
| imports.append(ImportItem(name, type_checking_only=True, alias=alias)) |
| if alias: |
| return alias |
| return suffix |
| |
| return _run |
| |
| |
| def generate_global_funcs( |
| code: CodeBlock, |
| global_funcs: list[FuncInfo], |
| ty_map: dict[str, str], |
| imports: list[ImportItem], |
| opt: Options, |
| ) -> None: |
| """Generate function signatures for global functions. |
| |
| It processes: global/${prefix}@${import_from="tvm_ffi") |
| """ |
| assert len(code.lines) >= 2 |
| if not global_funcs: |
| return |
| assert isinstance(code.param, tuple) |
| prefix, import_from = code.param |
| if not import_from: |
| import_from = "tvm_ffi" |
| imports.extend( |
| [ |
| ImportItem( |
| f"{import_from}.init_ffi_api", |
| type_checking_only=False, |
| alias="_FFI_INIT_FUNC", |
| ), |
| ImportItem( |
| "typing.TYPE_CHECKING", |
| type_checking_only=False, |
| ), |
| ] |
| ) |
| func_names = {f.schema.name.rsplit(".", 1)[-1] for f in global_funcs} |
| fn_ty_map = _type_suffix_and_record(ty_map, imports, func_names=func_names) |
| results: list[str] = [ |
| "# fmt: off", |
| f'_FFI_INIT_FUNC("{prefix}", __name__)', |
| "if TYPE_CHECKING:", |
| *[func.gen(fn_ty_map, indent=opt.indent) for func in global_funcs], |
| "# fmt: on", |
| ] |
| indent = " " * code.indent |
| code.lines = [ |
| code.lines[0], |
| *[indent + line for line in results], |
| code.lines[-1], |
| ] |
| |
| |
| def generate_object( |
| code: CodeBlock, |
| ty_map: dict[str, str], |
| imports: list[ImportItem], |
| opt: Options, |
| obj_info: ObjectInfo, |
| ) -> None: |
| """Generate a class definition for an object type. |
| |
| It processes: object/${type_key} |
| """ |
| assert len(code.lines) >= 2 |
| info = obj_info |
| method_names = {m.schema.name.rsplit(".", 1)[-1] for m in info.methods} |
| fn_ty_map = _type_suffix_and_record(ty_map, imports, func_names=method_names) |
| init_lines = info.gen_init(fn_ty_map, indent=opt.indent) |
| ffi_init_lines = info.gen_ffi_init(fn_ty_map, indent=opt.indent) |
| type_checking_lines = [ |
| *init_lines, |
| *ffi_init_lines, |
| *info.gen_methods(fn_ty_map, indent=opt.indent), |
| ] |
| if type_checking_lines: |
| imports.append( |
| ImportItem( |
| "typing.TYPE_CHECKING", |
| type_checking_only=False, |
| ) |
| ) |
| results = [ |
| "# fmt: off", |
| *info.gen_fields(fn_ty_map, indent=0), |
| "if TYPE_CHECKING:", |
| *type_checking_lines, |
| "# fmt: on", |
| ] |
| else: |
| results = [ |
| "# fmt: off", |
| *info.gen_fields(fn_ty_map, indent=0), |
| "# fmt: on", |
| ] |
| indent = " " * code.indent |
| code.lines = [ |
| code.lines[0], |
| *[indent + line for line in results], |
| code.lines[-1], |
| ] |
| |
| |
| def generate_import_section( |
| code: CodeBlock, |
| imports: list[ImportItem], |
| opt: Options, |
| ) -> None: |
| """Generate import statements for the types used in the stub. |
| |
| It processes: import-section |
| """ |
| imports_concrete: dict[str, list[ImportItem]] = {} |
| imports_ty_check: dict[str, list[ImportItem]] = {} |
| for item in imports: |
| if item.type_checking_only: |
| imports_ty_check.setdefault(item.mod, []).append(item) |
| else: |
| imports_concrete.setdefault(item.mod, []).append(item) |
| if imports_ty_check: |
| imports_concrete.setdefault("typing", []).append( |
| ImportItem("typing.TYPE_CHECKING", type_checking_only=True) |
| ) |
| |
| def _make_line(mod: str, items: list[ImportItem], indent: int) -> str: |
| items.sort(key=lambda item: item.name) |
| names = ", ".join(sorted(set(item.name_with_alias for item in items))) |
| indent_str = " " * indent |
| if mod: |
| return f"{indent_str}from {mod} import {names}" |
| else: |
| return f"{indent_str}import {names}" |
| |
| results: list[str] = [] |
| if imports_concrete: |
| results.extend( |
| _make_line(mod, imports_concrete[mod], indent=0) for mod in sorted(imports_concrete) |
| ) |
| if imports_ty_check: |
| results.append("if TYPE_CHECKING:") |
| results.extend( |
| _make_line(mod, imports_ty_check[mod], opt.indent) for mod in sorted(imports_ty_check) |
| ) |
| if results: |
| code.lines = [ |
| code.lines[0], |
| "# fmt: off", |
| "# isort: off", |
| "from __future__ import annotations", |
| *results, |
| "# isort: on", |
| "# fmt: on", |
| code.lines[-1], |
| ] |
| |
| |
| def generate_all(code: CodeBlock, names: set[str], opt: Options) -> None: |
| """Generate an `__all__` variable for the given names.""" |
| assert len(code.lines) >= 2 |
| if not names: |
| return |
| |
| indent = " " * code.indent |
| names = {f.rsplit(".", 1)[-1] for f in names} |
| |
| def _sort_key(name: str) -> tuple[int, str]: |
| if name.isupper(): |
| return (0, name) |
| if name and name[0].isupper() and not "_" in name: |
| return (1, name) |
| return (2, name) |
| |
| code.lines = [ |
| code.lines[0], |
| *[f'{indent}"{name}",' for name in sorted(names, key=_sort_key)], |
| code.lines[-1], |
| ] |
| |
| |
| def generate_export(code: CodeBlock) -> None: |
| """Generate an `__all__` variable for the given names.""" |
| assert len(code.lines) >= 2 |
| |
| mod = code.param |
| code.lines = [ |
| code.lines[0], |
| "# fmt: off", |
| "# isort: off", |
| f"from .{mod} import * # noqa: F403", |
| f"from .{mod} import __all__ as {mod}__all__", |
| 'if "__all__" not in globals():', |
| " __all__ = []", |
| f"__all__.extend({mod}__all__)", |
| "# isort: on", |
| "# fmt: on", |
| code.lines[-1], |
| ] |
| |
| |
| def generate_ffi_api( |
| code_blocks: list[CodeBlock], |
| ty_map: dict[str, str], |
| module_name: str, |
| object_infos: list[ObjectInfo], |
| init_cfg: InitConfig, |
| is_root: bool, |
| ) -> str: |
| """Generate the initial FFI API stub code for a given module.""" |
| # TODO(@junrus): New code is appended to the end of the file. |
| # We should consider a more sophisticated approach. |
| append = "" |
| |
| # Part 0. Imports |
| if not code_blocks: |
| append += f"""\"\"\"FFI API bindings for {module_name}.\"\"\"\n""" |
| if not any(code.kind == "import-section" for code in code_blocks): |
| append += C.PROMPT_IMPORT_SECTION |
| |
| # Part 1. Library loading |
| if is_root: |
| append += C._prompt_import_object("tvm_ffi.libinfo.load_lib_module", "_FFI_LOAD_LIB") |
| append += f"""LIB = _FFI_LOAD_LIB("{init_cfg.pkg}", "{init_cfg.shared_target}")\n""" |
| |
| # Part 2. Global functions |
| if not any(code.kind == "global" for code in code_blocks): |
| append += C._prompt_globals(module_name) |
| |
| # Part 3. Object types |
| if object_infos: |
| append += C._prompt_import_object("tvm_ffi.register_object", "_FFI_REG_OBJ") |
| |
| defined_type_keys = {info.type_key for info in object_infos if info.type_key} |
| for info in object_infos: |
| type_key = info.type_key |
| parent_type_key = info.parent_type_key |
| if type_key is None: |
| continue |
| # Canonicalize type key names |
| type_key = ty_map.get(type_key, type_key) |
| type_name = type_key.rsplit(".", 1)[-1] |
| parent_type_key = ( |
| ty_map.get(parent_type_key, parent_type_key) if parent_type_key else parent_type_key |
| ) |
| parent_type_name = parent_type_key.rsplit(".", 1)[-1] if parent_type_key else "Object" |
| # Import parent type keys if they are not defined in the current module |
| if parent_type_key and parent_type_key not in defined_type_keys: |
| parent_type_name = "_" + parent_type_key.replace(".", "_") |
| append += C._prompt_import_object(parent_type_key, parent_type_name) |
| # Generate class definition |
| append += C._prompt_class_def( |
| type_name, |
| type_key, |
| parent_type_name, |
| ) |
| # Part 4. __all__ |
| if not any(code.kind == "__all__" for code in code_blocks): |
| append += C.PROMPT_ALL_SECTION |
| return append |
| |
| |
| def generate_init( |
| code_blocks: list[CodeBlock], |
| module_name: str, |
| submodule: str = "_ffi_api", |
| ) -> str: |
| """Generate the `__init__.py` file for the `tvm_ffi` package.""" |
| code = f""" |
| {C.STUB_BEGIN} export/{submodule} |
| {C.STUB_END} |
| """ |
| if not code_blocks: |
| return f"""\"\"\"Package {module_name}.\"\"\"\n""" + code |
| if not any(code.kind == "export" for code in code_blocks): |
| return code |
| return "" |