| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| //! Provides the [`Module`] type and methods for working with runtime TVM modules. |
| |
| use std::{ |
| convert::TryInto, |
| ffi::CString, |
| os::raw::{c_char, c_int}, |
| path::Path, |
| ptr, |
| }; |
| |
| use failure::Error; |
| use tvm_common::ffi; |
| |
| use crate::{errors, function::Function}; |
| |
/// Name under which a module's entry function is looked up.
const ENTRY_FUNC: &str = "__tvm_main__";
| |
| /// Wrapper around TVM module handle which contains an entry function. |
| /// The entry function can be applied to an imported module through [`entry_func`]. |
| /// |
| /// [`entry_func`]:struct.Module.html#method.entry_func |
| #[derive(Debug, Clone)] |
| pub struct Module { |
| pub(crate) handle: ffi::TVMModuleHandle, |
| entry_func: Option<Function>, |
| } |
| |
| impl Module { |
| pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self { |
| Self { |
| handle, |
| entry_func: None, |
| } |
| } |
| |
| pub fn entry(&mut self) -> Option<&Function> { |
| if self.entry_func.is_none() { |
| self.entry_func = self.get_function(ENTRY_FUNC, false).ok(); |
| } |
| self.entry_func.as_ref() |
| } |
| |
| /// Gets a function by name from a registered module. |
| pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function, Error> { |
| let name = CString::new(name)?; |
| let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; |
| check_call!(ffi::TVMModGetFunction( |
| self.handle, |
| name.as_ptr() as *const c_char, |
| query_import as c_int, |
| &mut fhandle as *mut _ |
| )); |
| ensure!( |
| !fhandle.is_null(), |
| errors::NullHandleError { |
| name: format!("{}", name.into_string()?) |
| } |
| ); |
| Ok(Function::new(fhandle)) |
| } |
| |
| /// Imports a dependent module such as `.ptx` for gpu. |
| pub fn import_module(&self, dependent_module: Module) { |
| check_call!(ffi::TVMModImport(self.handle, dependent_module.handle)) |
| } |
| |
| /// Loads a module shared library from path. |
| pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module, Error> { |
| let ext = CString::new( |
| path.as_ref() |
| .extension() |
| .unwrap_or(std::ffi::OsStr::new("")) |
| .to_str() |
| .ok_or_else(|| { |
| format_err!("Bad module load path: `{}`.", path.as_ref().display()) |
| })?, |
| )?; |
| let func = Function::get("module._LoadFromFile").expect("API function always exists"); |
| let cpath = |
| CString::new(path.as_ref().to_str().ok_or_else(|| { |
| format_err!("Bad module load path: `{}`.", path.as_ref().display()) |
| })?)?; |
| let ret: Module = call_packed!(func, cpath.as_c_str(), ext.as_c_str())?.try_into()?; |
| Ok(ret) |
| } |
| |
| /// Checks if a target device is enabled for a module. |
| pub fn enabled(&self, target: &str) -> bool { |
| let func = Function::get("module._Enabled").expect("API function always exists"); |
| // `unwrap` is safe here because if there is any error during the |
| // function call, it would occur in `call_packed!`. |
| let tgt = CString::new(target).unwrap(); |
| let ret: i64 = call_packed!(func, tgt.as_c_str()) |
| .unwrap() |
| .try_into() |
| .unwrap(); |
| ret != 0 |
| } |
| |
| /// Returns the underlying module handle. |
| pub fn handle(&self) -> ffi::TVMModuleHandle { |
| self.handle |
| } |
| } |
| |
| impl Drop for Module { |
| fn drop(&mut self) { |
| check_call!(ffi::TVMModFree(self.handle)); |
| } |
| } |