blob: ea0f0bc50369c25899d0a790fae85b3fa45db1ea [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Stateful helpers for querying TVM FFI runtime metadata."""
from __future__ import annotations
import functools
import heapq
from collections import defaultdict
from tvm_ffi._ffi_api import GetRegisteredTypeKeys
from tvm_ffi.core import TypeSchema, _lookup_or_register_type_info_from_type_key
from tvm_ffi.registry import get_global_func_metadata, list_global_func_names
from . import consts as C
from .utils import FuncInfo, NamedTypeSchema, ObjectInfo
@functools.lru_cache(maxsize=None)
def object_info_from_type_key(type_key: str) -> ObjectInfo:
    """Build the `ObjectInfo` describing the object type named by `type_key`.

    Results are memoized without a size bound, so each distinct type key is
    resolved at most once per process.

    Parameters
    ----------
    type_key
        The type key to resolve. It is converted with `str()` before the
        lookup.

    Returns
    -------
    ObjectInfo
        The info built from the type info returned by the lookup.

    """
    registered = _lookup_or_register_type_info_from_type_key(str(type_key))
    # Sanity check: the lookup must hand back the entry for the requested key.
    assert registered.type_key == type_key
    return ObjectInfo.from_type_info(registered)
def collect_global_funcs() -> dict[str, list[FuncInfo]]:
    """Group the registered global functions by their name prefix.

    The prefix of a function is everything before the last ``.`` in its
    name. Names without a ``.`` are reported and skipped, as are functions
    for which building a `FuncInfo` raises.

    Returns
    -------
    dict[str, list[FuncInfo]]
        Mapping from prefix to its functions, each list sorted by
        ``schema.name``.

    """
    grouped: dict[str, list[FuncInfo]] = {}
    for name in list_global_func_names():
        pieces = name.rsplit(".", 1)
        if len(pieces) != 2:
            # No "." in the name, so there is no prefix to group it under.
            print(f"{C.TERM_YELLOW}[Skipped] Invalid name in global function: {name}{C.TERM_RESET}")
            continue
        prefix = pieces[0]
        try:
            # The bucket is created before the lookup runs, so a prefix keeps
            # its (possibly empty) list even when the lookup below fails.
            bucket = grouped.setdefault(prefix, [])
            bucket.append(_func_info_from_global_name(name))
        except Exception:
            print(f"{C.TERM_YELLOW}[Skipped] Function has no type schema: {name}{C.TERM_RESET}")
    for funcs in grouped.values():
        funcs.sort(key=lambda info: info.schema.name)
    return grouped
def collect_type_keys() -> dict[str, list[str]]:
    """Group the registered object type keys by their prefix.

    The prefix of a type key is everything before its last ``.``. Type keys
    without a ``.`` are silently left out.

    Returns
    -------
    dict[str, list[str]]
        Mapping from prefix to the sorted list of type keys under it.

    """
    grouped: dict[str, list[str]] = {}
    for type_key in GetRegisteredTypeKeys():
        pieces = type_key.rsplit(".", 1)
        if len(pieces) != 2:
            # No "." in the key, hence no prefix; such keys are not collected.
            continue
        grouped.setdefault(pieces[0], []).append(type_key)
    for keys in grouped.values():
        keys.sort()
    return grouped
def toposort_objects(type_keys: list[str]) -> list[ObjectInfo]:
    """Resolve type keys to `ObjectInfo` objects, ordered parents-first.

    Duplicate type keys are resolved once. A parent is only taken into
    account when it is itself among the given type keys. Whenever several
    types are ready to be emitted, the smallest type key (by string
    comparison) comes first.

    Parameters
    ----------
    type_keys
        The type keys to resolve; duplicates are allowed.

    Returns
    -------
    list[ObjectInfo]
        One entry per distinct type key, with every type placed after its
        parent when that parent is part of the input.

    """
    # Resolve each distinct key once, in first-occurrence order.
    infos: dict[str, ObjectInfo] = {}
    for key in type_keys:
        if key not in infos:
            infos[key] = object_info_from_type_key(key)

    # Number of not-yet-emitted parents (0 or 1) for every key.
    pending_parents: dict[str, int] = dict.fromkeys(infos, 0)
    children_of: dict[str, list[str]] = {}
    for key, info in infos.items():
        parent = info.parent_type_key
        if parent in infos:
            children_of.setdefault(parent, []).append(key)
            pending_parents[key] += 1
    for siblings in children_of.values():
        siblings.sort()

    # Min-heap of keys whose parent (if any) has already been emitted.
    ready: list[str] = [key for key, count in pending_parents.items() if count == 0]
    heapq.heapify(ready)
    ordered: list[ObjectInfo] = []
    while ready:
        current = heapq.heappop(ready)
        ordered.append(infos[current])
        for child in children_of.get(current, ()):
            pending_parents[child] -= 1
            if pending_parents[child] == 0:
                heapq.heappush(ready, child)
    # Fewer emitted entries would mean some key never became ready.
    assert len(ordered) == len(infos)
    return ordered
@functools.lru_cache(maxsize=None)
def _func_info_from_global_name(name: str) -> FuncInfo:
    """Build the `FuncInfo` for the global function registered as `name`.

    The function's metadata is fetched, its ``"type_schema"`` entry is parsed
    as a JSON string into a `TypeSchema`, and the result is wrapped as a
    non-member function. Results are memoized without a size bound.
    """
    metadata = get_global_func_metadata(name)
    parsed_schema = TypeSchema.from_json_str(metadata["type_schema"])
    named_schema = NamedTypeSchema(name=name, schema=parsed_schema)
    return FuncInfo(schema=named_schema, is_member=False)