blob: ee6ed05a74f776863bd50e6601e8f2d6409bde0a [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement, unused-import
"""Function configuration API."""
import ctypes
import traceback
from numbers import Number, Integral
from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
from ..base import c_str, string_types
from ..runtime_ctypes import DataType, TVMByteArray, Device, ObjectRValueRef
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, ArgTypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _device_to_int64
from .object import ObjectBase, PyNativeObject, _set_class_object
from . import object as _object
PackedFuncHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
ObjectHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle):
"""callback to free resources when it is not needed."""
pyobj = ctypes.cast(rhandle, ctypes.py_object)
ctypes.pythonapi.Py_DecRef(pyobj)
# Global callback that is always alive
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
def _make_packed_func(handle, is_global):
"""Make a packed function class"""
obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
obj.is_global = is_global
obj.handle = handle
return obj
def convert_to_tvm_func(pyfunc):
"""Convert a python function to TVM function
Parameters
----------
pyfunc : python function
The python function to be converted.
Returns
-------
tvmfunc: tvm.nd.Function
The converted tvm function.
"""
local_pyfunc = pyfunc
def cfun(args, type_codes, num_args, ret, _):
"""ctypes function"""
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
pyargs = (C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args))
# pylint: disable=broad-except
try:
rv = local_pyfunc(*pyargs)
except Exception:
msg = traceback.format_exc()
msg = py2cerror(msg)
_LIB.TVMAPISetLastError(c_str(msg))
return -1
if rv is not None:
if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value")
temp_args = []
values, tcodes, _ = _make_tvm_args((rv,), temp_args)
if not isinstance(ret, TVMRetValueHandle):
ret = TVMRetValueHandle(ret)
if _LIB.TVMCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1)) != 0:
raise get_last_ffi_error()
_ = temp_args
_ = rv
return 0
handle = PackedFuncHandle()
f = TVMPackedCFunc(cfun)
# NOTE: We will need to use python-api to increase ref count of the f
# TVM_FREE_PYOBJ will be called after it is no longer needed.
pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj)
if _LIB.TVMFuncCreateFromCFunc(f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0:
raise get_last_ffi_error()
return _make_packed_func(handle, False)
def _make_tvm_args(args, temp_args):
"""Pack arguments into c args tvm call accept"""
num_args = len(args)
values = (TVMValue * num_args)()
type_codes = (ctypes.c_int * num_args)()
for i, arg in enumerate(args):
if isinstance(arg, ObjectBase):
values[i].v_handle = arg.handle
type_codes[i] = ArgTypeCode.OBJECT_HANDLE
elif arg is None:
values[i].v_handle = None
type_codes[i] = ArgTypeCode.NULL
elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (
ArgTypeCode.NDARRAY_HANDLE if not arg.is_view else ArgTypeCode.DLTENSOR_HANDLE
)
elif isinstance(arg, PyNativeObject):
values[i].v_handle = arg.__tvm_object__.handle
type_codes[i] = ArgTypeCode.OBJECT_HANDLE
elif isinstance(arg, _nd._TVM_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
type_codes[i] = arg.__class__._tvm_tcode
elif isinstance(arg, Integral):
values[i].v_int64 = arg
type_codes[i] = ArgTypeCode.INT
elif isinstance(arg, Number):
values[i].v_float64 = arg
type_codes[i] = ArgTypeCode.FLOAT
elif isinstance(arg, DataType):
values[i].v_str = c_str(str(arg))
type_codes[i] = ArgTypeCode.STR
elif isinstance(arg, Device):
values[i].v_int64 = _device_to_int64(arg)
type_codes[i] = ArgTypeCode.DLDEVICE
elif isinstance(arg, (bytearray, bytes)):
# from_buffer only taeks in bytearray.
if isinstance(arg, bytes):
byte_arr = bytearray(arg)
temp_args.append(byte_arr)
arg = byte_arr
arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg), ctypes.POINTER(ctypes.c_byte)
)
arr.size = len(arg)
values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
temp_args.append(arr)
type_codes[i] = ArgTypeCode.BYTES
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = ArgTypeCode.STR
elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
arg = _FUNC_CONVERT_TO_OBJECT(arg)
values[i].v_handle = arg.handle
type_codes[i] = ArgTypeCode.OBJECT_HANDLE
temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE):
values[i].v_handle = arg.handle
type_codes[i] = ArgTypeCode.MODULE_HANDLE
elif isinstance(arg, PackedFuncBase):
values[i].v_handle = arg.handle
type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE
elif isinstance(arg, ctypes.c_void_p):
values[i].v_handle = arg
type_codes[i] = ArgTypeCode.HANDLE
elif isinstance(arg, ObjectRValueRef):
values[i].v_handle = ctypes.cast(ctypes.byref(arg.obj.handle), ctypes.c_void_p)
type_codes[i] = ArgTypeCode.OBJECT_RVALUE_REF_ARG
elif callable(arg):
arg = convert_to_tvm_func(arg)
values[i].v_handle = arg.handle
type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE
temp_args.append(arg)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
class PackedFuncBase(object):
"""Function base."""
__slots__ = ["handle", "is_global"]
# pylint: disable=no-member
def __init__(self, handle, is_global):
"""Initialize the function with handle
Parameters
----------
handle : PackedFuncHandle
the handle to the underlying function.
is_global : bool
Whether this is a global function in python
"""
self.handle = handle
self.is_global = is_global
def __del__(self):
if not self.is_global and _LIB is not None:
if _LIB.TVMFuncFree(self.handle) != 0:
raise get_last_ffi_error()
def __call__(self, *args):
"""Call the function with positional arguments
args : list
The positional arguments to the function call.
"""
temp_args = []
values, tcodes, num_args = _make_tvm_args(args, temp_args)
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
if (
_LIB.TVMFuncCall(
self.handle,
values,
tcodes,
ctypes.c_int(num_args),
ctypes.byref(ret_val),
ctypes.byref(ret_tcode),
)
!= 0
):
raise get_last_ffi_error()
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
def __init_handle_by_constructor__(fconstructor, args):
"""Initialize handle by constructor"""
temp_args = []
values, tcodes, num_args = _make_tvm_args(args, temp_args)
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
if (
_LIB.TVMFuncCall(
fconstructor.handle,
values,
tcodes,
ctypes.c_int(num_args),
ctypes.byref(ret_val),
ctypes.byref(ret_tcode),
)
!= 0
):
raise get_last_ffi_error()
_ = temp_args
_ = args
assert ret_tcode.value == ArgTypeCode.OBJECT_HANDLE
handle = ret_val.v_handle
return handle
def _return_module(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, ModuleHandle):
handle = ModuleHandle(handle)
return _CLASS_MODULE(handle)
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, PackedFuncHandle):
handle = PackedFuncHandle(handle)
return _CLASS_PACKED_FUNC(handle, False)
def _get_global_func(name, allow_missing=False):
handle = PackedFuncHandle()
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
if handle.value:
return _make_packed_func(handle, False)
if allow_missing:
return None
raise ValueError("Cannot find global function %s" % name)
# setup return handle for function type
_object.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[ArgTypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
C_TO_PY_ARG_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, ArgTypeCode.PACKED_FUNC_HANDLE
)
C_TO_PY_ARG_SWITCH[ArgTypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, ArgTypeCode.MODULE_HANDLE
)
C_TO_PY_ARG_SWITCH[ArgTypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle, True, False)
C_TO_PY_ARG_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = _wrap_arg_func(
lambda x: _make_array(x.v_handle, False, True), ArgTypeCode.NDARRAY_HANDLE
)
_CLASS_MODULE = None
_CLASS_PACKED_FUNC = None
_CLASS_OBJECT_GENERIC = None
_FUNC_CONVERT_TO_OBJECT = None
def _set_class_module(module_class):
"""Initialize the module."""
global _CLASS_MODULE
_CLASS_MODULE = module_class
def _set_class_packed_func(packed_func_class):
global _CLASS_PACKED_FUNC
_CLASS_PACKED_FUNC = packed_func_class
def _set_class_object_generic(object_generic_class, func_convert_to_object):
global _CLASS_OBJECT_GENERIC
global _FUNC_CONVERT_TO_OBJECT
_CLASS_OBJECT_GENERIC = object_generic_class
_FUNC_CONVERT_TO_OBJECT = func_convert_to_object