| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # coding: utf-8 |
| # pylint: disable=invalid-name, import-outside-toplevel |
| """Base library for TVM FFI.""" |
| import sys |
| import os |
| import ctypes |
| import numpy as np |
| from . import libinfo |
| |
# ----------------------------
# library loading
# ----------------------------
# Python types treated as strings.
# NOTE(review): presumably used by callers for isinstance checks on values
# passed through the FFI — confirm at the call sites.
string_types = (str,)
# Python and NumPy scalar types treated as integers.
# NOTE(review): np.int64 is not listed here — confirm whether that is
# intentional.
integer_types = (int, np.int32)
# Scalar types treated as numeric: the integer types above plus float types.
numeric_types = integer_types + (float, np.float16, np.float32)
| |
# py_str converts the bytes returned by C APIs (for example the ``.value`` of
# a ctypes.c_char_p) back into a Python str.
if sys.platform != "win32":

    def py_str(x):
        """Decode the UTF-8 encoded bytes ``x`` into a str."""
        return x.decode("utf-8")

else:

    def _py_str(x):
        """Decode bytes ``x`` as UTF-8, falling back to the ANSI code page.

        When ``x`` is not valid UTF-8, it is decoded with the Windows code
        page reported by ``kernel32.GetACP()`` (as ``"cp<N>"``).
        """
        try:
            return x.decode("utf-8")
        except UnicodeDecodeError:
            active_code_page = ctypes.cdll.kernel32.GetACP()
            return x.decode("cp" + str(active_code_page))

    py_str = _py_str
| |
| |
def _load_lib():
    """Load the TVM shared library found by ``libinfo.find_lib_path``.

    The first candidate path returned by ``libinfo.find_lib_path()`` is
    loaded with ``RTLD_GLOBAL``. On Windows with Python 3.8 or newer, the
    directories from ``libinfo.get_dll_directories()`` are registered first,
    because DLL search paths must be added explicitly there.

    Returns
    -------
    lib : ctypes.CDLL
        The loaded library, with ``TVMGetLastError`` configured to return
        ``c_char_p``.

    name : str
        The file name (basename) of the loaded library.
    """
    candidate_paths = libinfo.find_lib_path()
    needs_dll_dirs = sys.platform.startswith("win32") and sys.version_info >= (3, 8)
    if needs_dll_dirs:
        for dll_dir in libinfo.get_dll_directories():
            os.add_dll_directory(dll_dir)
    chosen_path = candidate_paths[0]
    lib = ctypes.CDLL(chosen_path, ctypes.RTLD_GLOBAL)
    lib.TVMGetLastError.restype = ctypes.c_char_p
    return lib, os.path.basename(chosen_path)
| |
| |
try:
    # The following import is needed for TVM to work with pdb
    import readline  # pylint: disable=unused-import
except ImportError:
    # readline is optional; continue without it when it is unavailable.
    pass

# version number, taken from libinfo
__version__ = libinfo.__version__
# library instance: the ctypes handle of the loaded TVM shared library and its
# file name. NOTE: the library is loaded as a side effect of importing this
# module.
_LIB, _LIB_NAME = _load_lib()

# Whether we are runtime only. This is inferred from the loaded library's file
# name: True when the name contains "runtime".
_RUNTIME_ONLY = "runtime" in _LIB_NAME

# The FFI mode of TVM, read from the TVM_FFI environment variable and
# defaulting to "auto". NOTE(review): the accepted values are interpreted by
# whichever module reads _FFI_MODE — confirm there.
_FFI_MODE = os.environ.get("TVM_FFI", "auto")
| |
| |
# ----------------------------
# ctypes helper functions.
# ----------------------------
def c_str(string):
    """Wrap a Python str as a UTF-8 encoded ctypes ``char *``.

    Parameters
    ----------
    string : str
        The Python string to convert.

    Returns
    -------
    str : ctypes.c_char_p
        A char pointer to the UTF-8 bytes, suitable for passing to the C API.
    """
    utf8_bytes = string.encode("utf-8")
    return ctypes.c_char_p(utf8_bytes)
| |
| |
def c_array(ctype, values):
    """Build a ctypes array of element type ``ctype`` holding ``values``.

    Parameters
    ----------
    ctype : ctypes data type
        The element type of the resulting array.

    values : tuple or list
        The elements to copy into the array, in order.

    Returns
    -------
    out : ctypes array
        A new ctypes array of length ``len(values)``.
    """
    array_type = ctype * len(values)
    return array_type(*values)
| |
| |
def decorate(func, fwrapped):
    """Apply ``fwrapped`` around ``func`` using the ``decorator`` package.

    The ``decorator`` package is imported lazily, when this function is
    called, so it is only required by code that actually uses this helper.

    Parameters
    ----------
    func : function
        The original function.

    fwrapped : function
        The caller invoked in place of ``func``.

    Returns
    -------
    The result of ``decorator.decorate(func, fwrapped)``.
    """
    from decorator import decorate as _decorator_decorate

    return _decorator_decorate(func, fwrapped)
| |
| |
| # ----------------------------------------- |
| # Base code for structured error handling. |
| # ----------------------------------------- |
# Maps an error name to the exception class used to construct it.
# register_error fills this in; get_last_ffi_error looks names up here (after
# removing any "tvm.error." prefix) and falls back to TVMError when absent.
ERROR_TYPE = {}
| |
| |
class TVMError(RuntimeError):
    """Default error thrown by TVM functions.

    TVMError is raised when the C API error message does not name an error
    type that has been registered with ``register_error``.
    """
| |
| |
def register_error(func_name=None, cls=None):
    """Register an error class so the FFI error handler can recognize it.

    This can be used as a bare class decorator, as a decorator factory that
    takes an explicit name, or called directly with both a name and a class.

    Parameters
    ----------
    func_name : str or class, optional
        The name to register the error under. If a callable (typically the
        error class itself) is passed here, it is registered under its own
        ``__name__``. If omitted or not a string, the class name is used.

    cls : class, optional
        The error class to register.

    Returns
    -------
    fregister : function or class
        The registration decorator when ``cls`` is not given; otherwise the
        registered class itself.

    Examples
    --------
    .. code-block:: python

      @tvm.error.register_error
      class MyError(RuntimeError):
          pass

      err_inst = tvm.error.create_ffi_error("MyError: xyz")
      assert isinstance(err_inst, MyError)
    """
    # Bare-decorator form: the class arrived in the first positional slot.
    if callable(func_name):
        cls, func_name = func_name, func_name.__name__

    def register(mycls):
        """Record ``mycls`` in ERROR_TYPE and return it unchanged."""
        key = func_name if isinstance(func_name, str) else mycls.__name__
        ERROR_TYPE[key] = mycls
        return mycls

    return register if cls is None else register(cls)
| |
| |
| def _valid_error_name(name): |
| """Check whether name is a valid error name.""" |
| return all(x.isalnum() or x in "_." for x in name) |
| |
| |
def _find_error_type(line):
    """Extract the error type name from the first line of an error message.

    Parameters
    ----------
    line : str
        The first line of the error message.

    Returns
    -------
    name : str or None
        The error name, or None if no valid name could be found.
    """
    if sys.platform == "win32":
        # Stack traces aren't logged on Windows due to a DMLC limitation,
        # so we should try to get the underlying error another way.
        # DMLC formats errors "[timestamp] file:line: ErrorMessage"
        # ErrorMessage is usually formatted "ErrorType: message"
        # So take the text between the last two ":" (or from the start of
        # the line when there is only one ":").
        last_colon = line.rfind(":")
        if last_colon < 0:
            return None
        prev_colon = line.rfind(":", 0, last_colon)
        # When prev_colon is -1 the slice starts at 0, i.e. the line start.
        candidate = line[prev_colon + 1 : last_colon].strip()
        return candidate if _valid_error_name(candidate) else None

    # Elsewhere the message starts with "ErrorType: ...".
    head, separator, _ = line.partition(":")
    if not separator:
        return None
    return head if _valid_error_name(head) else None
| |
| |
def c2pyerror(err_msg):
    """Translate a C API error message into a Python-style message.

    The C message consists of the error text, optionally followed by a
    ``Stack trace`` section of indented frame lines. The frames are emitted in
    reverse order under a ``Traceback (most recent call last):`` header,
    followed by the remaining message lines.

    Parameters
    ----------
    err_msg : str
        The error message.

    Returns
    -------
    new_msg : str
        Translated message.

    err_type : str or None
        Error type detected from the first line, or None when none is found
        (including when ``err_msg`` is empty).
    """
    # NOTE(review): assumes the C++ backtrace format in which each frame
    # starts with a two-space indent ("  0: symbol") and the frame's
    # continuation lines ("        at file.cc:12") use an eight-space indent.
    # The two prefixes must differ, or every frame after the first would be
    # merged into the first one. Confirm against the C++ logging code.
    frame_indent = "  "
    continuation_indent = "        "

    lines = err_msg.split("\n")
    if lines[-1] == "":
        lines.pop()
    # An empty message leaves no lines; avoid IndexError on lines[0].
    err_type = _find_error_type(lines[0]) if lines else None
    trace_mode = False
    stack_trace = []
    message = []
    for line in lines:
        if trace_mode:
            if line.startswith(continuation_indent) and stack_trace:
                # Keep continuation lines with their frame so that reversing
                # the frame order below does not separate them.
                stack_trace[-1] += "\n" + line
            elif line.startswith(frame_indent):
                stack_trace.append(line)
            else:
                # A non-indented line ends the trace section and is then
                # treated as a regular message line below.
                trace_mode = False
        if not trace_mode:
            if line.startswith("Stack trace"):
                trace_mode = True
            else:
                message.append(line)
    out_msg = ""
    if stack_trace:
        out_msg += "Traceback (most recent call last):\n"
        # Reversed so the frame order matches the Python-style header.
        out_msg += "\n".join(reversed(stack_trace)) + "\n"
    out_msg += "\n".join(message)
    return out_msg, err_type
| |
| |
def py2cerror(err_msg):
    """Translate a Python-style error message into the C style.

    Lines following a ``Traceback`` header that start with a space are
    collected as the stack trace. They are emitted in reverse order under a
    ``Stack trace:`` header after the message lines. A redundant leading error
    name is dropped, e.g. ``RuntimeError: MyErrorName: message`` becomes
    ``MyErrorName: message``.

    Parameters
    ----------
    err_msg : str
        The error message.

    Returns
    -------
    new_msg : str
        Translated message.
    """
    lines = err_msg.split("\n")
    if lines[-1] == "":
        lines.pop()
    trace_mode = False
    stack_trace = []
    message = []
    for line in lines:
        if trace_mode:
            if line.startswith(" "):
                stack_trace.append(line)
            else:
                # A non-indented line ends the traceback and is then treated
                # as a regular message line below.
                trace_mode = False
        if not trace_mode:
            if "Traceback" in line:
                trace_mode = True
            else:
                message.append(line)
    # Remove the first error name if there are two of them.
    # RuntimeError: MyErrorName: message => MyErrorName: message
    # message is empty when err_msg is empty or contains only a traceback.
    if message:
        head_arr = message[0].split(":", 3)
        if len(head_arr) >= 3 and _valid_error_name(head_arr[1].strip()):
            head_arr[1] = head_arr[1].strip()
            message[0] = ":".join(head_arr[1:])
    out_msg = "\n".join(message)
    if stack_trace:
        out_msg += "\nStack trace:\n"
        # reverse the stack trace.
        out_msg += "\n".join(reversed(stack_trace)) + "\n"
    return out_msg
| |
| |
def get_last_ffi_error():
    """Build the Python error object for the most recent C API error.

    The message from ``TVMGetLastError`` is translated with ``c2pyerror``.
    The detected error type, with any leading ``tvm.error.`` removed, selects
    the class from ERROR_TYPE; TVMError is used when no class is registered.

    Returns
    -------
    err : object
        The error object built from the translated message.
    """
    raw_msg = py_str(_LIB.TVMGetLastError())
    translated_msg, err_type = c2pyerror(raw_msg)
    prefix = "tvm.error."
    if err_type is not None and err_type.startswith(prefix):
        err_type = err_type[len(prefix):]
    error_cls = ERROR_TYPE.get(err_type, TVMError)
    return error_cls(translated_msg)
| |
| |
def check_call(ret):
    """Raise the pending FFI error if a C API call reported failure.

    Wrap every C API call with this function. A return value of 0 means
    success; any other value raises the error built by
    ``get_last_ffi_error()``.

    Parameters
    ----------
    ret : int
        Return value from the C API call.
    """
    if ret == 0:
        return
    raise get_last_ffi_error()