blob: 012de3123ce22750024f83f45eff1ffe9d657a55 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM FFI Python package."""
# Import order matters in this module, so isort is disabled for the whole file.
# isort: skip_file
import sys
from typing import TYPE_CHECKING
def _is_config_mode() -> bool:
"""Check user is invoking the config CLI entry."""
if sys.argv[0].endswith("tvm-ffi-config"):
return True
# sys.orig_argv is available only after python 3.10
if hasattr(sys, "orig_argv"):
# Use orig_argv because Python strips the `tvm_ffi.config` from sys.argv when using -m.
argv = sys.orig_argv
for i, arg in enumerate(argv):
if arg == "-m" and i + 1 < len(argv) and argv[i + 1] == "tvm_ffi.config":
return True
return False
# Eager package initialization.
# TYPE_CHECKING is only true for static type checkers, so they always see the
# full set of imports below. At runtime the eager imports are skipped in
# config-CLI mode (see _is_config_mode). On non-Windows platforms neither
# branch then runs, which leaves `LIB` and the public API names undefined.
if TYPE_CHECKING or not _is_config_mode():
    # Skip eager imports in CLI mode to avoid import
    # overhead in tvm-ffi-config command
    # HACK: try importing torch first, to avoid a potential
    # symbol conflict when both torch and tvm_ffi are imported.
    # This conflict can be reproduced in a very narrow scenario:
    # 1. GitHub action on Windows X64
    # 2. Python 3.12
    # 3. torch 2.9.0
    # Best effort: only a missing torch (ImportError) is ignored. When torch is
    # installed, this also binds the name `torch` in the package namespace.
    try:
        import torch
    except ImportError:
        pass
    # Always load base libtvm_ffi before any other imports
    # (the submodules imported below come after this load; ordering is
    # deliberate). The "RTLD_GLOBAL" argument presumably requests global
    # symbol visibility for libraries loaded later. TODO confirm in libinfo.
    from . import libinfo
    LIB = libinfo.load_lib_ctypes("apache-tvm-ffi", "tvm_ffi", "RTLD_GLOBAL")
    # Enable package initialization
    from .registry import (
        register_object,
        register_global_func,
        get_global_func,
        get_global_func_metadata,
        remove_global_func,
        init_ffi_api,
    )
    from ._dtype import dtype
    from .core import Object, ObjectConvertible, Function, CAny, CContainerBase
    from ._convert import convert, convert_func
    from .error import register_error
    from ._tensor import Device, device, DLDeviceType
    from ._tensor import from_dlpack, Tensor, Shape
    from .container import Array, Dict, List, Map
    from .dataclasses.py_class import method
    from .module import Module, system_lib, load_module
    from .stream import StreamContext, get_raw_stream, use_raw_stream, use_torch_stream
    from .structural import (
        StructuralKey,
        get_first_structural_mismatch,
        structural_equal,
        structural_hash,
    )
    # Submodules exposed as package attributes (tvm_ffi.serialization, ...).
    from . import serialization
    from . import access_path
    from . import dataclasses
    from . import structural
    from . import cpp
    # optional module to speedup dlpack conversion
    from . import _optional_torch_c_dlpack
    # import the dtype literals
    # Note: this binds the package-level name `bool` to the dtype literal,
    # shadowing the builtin `bool` in this module's namespace from here on.
    from ._dtype import (
        bool,
        int8,
        int16,
        int32,
        int64,
        uint8,
        uint16,
        uint32,
        uint64,
        float64,
        float32,
        float16,
        bfloat16,
        float8_e4m3fn,
        float8_e4m3fnuz,
        float8_e5m2,
        float8_e5m2fnuz,
        float8_e8m0fnu,
        float4_e2m1fnx2,
    )
elif sys.platform.startswith("win32"):
    # On Windows, load the library even in config CLI mode so the DLL search path
    # is set correctly (needed in some cases when test still loads cython extensions).
    # Only `libinfo` and `LIB` are bound in this branch; none of the public
    # API names from the branch above are imported.
    from . import libinfo
    LIB = libinfo.load_lib_ctypes("apache-tvm-ffi", "tvm_ffi", "RTLD_GLOBAL")
# normal version imports
try:
from ._version import __version__, __version_tuple__
except ImportError:
__version__ = "0.0.0.dev0"
__version_tuple__ = (0, 0, 0, "dev0", "7d34eb8ab.d20250913")
# Public names re-exported by ``from tvm_ffi import *``.
# Most of these are only bound when the eager-import branch above runs. In
# config-CLI mode they are undefined, so a star-import would fail there.
# The dtype literals (``bool``, ``int8``, ...) and ``CAny``/``CContainerBase``
# are imported above but not listed here. Star-imports therefore do not pick
# them up, which also avoids shadowing the builtin ``bool`` in the importing
# module. NOTE(review): presumably intentional; confirm before adding them.
# Keep the list sorted (uppercase names first, then dunders, then lowercase).
__all__ = [
    "LIB",
    "Array",
    "DLDeviceType",
    "Device",
    "Dict",
    "Function",
    "List",
    "Map",
    "Module",
    "Object",
    "ObjectConvertible",
    "Shape",
    "StreamContext",
    "StructuralKey",
    "Tensor",
    "__version__",
    "__version_tuple__",
    "access_path",
    "convert",
    "convert_func",
    "cpp",
    "dataclasses",
    "device",
    "dtype",
    "from_dlpack",
    "get_first_structural_mismatch",
    "get_global_func",
    "get_global_func_metadata",
    "get_raw_stream",
    "init_ffi_api",
    "load_module",
    "method",
    "register_error",
    "register_global_func",
    "register_object",
    "remove_global_func",
    "serialization",
    "structural",
    "structural_equal",
    "structural_hash",
    "system_lib",
    "use_raw_stream",
    "use_torch_stream",
]