blob: b829d9183703e5383a2b360d627d3728dc6cc9e7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
//! This module implements conversions between [`TVMArgValue`]/[`TVMRetValue`]
//! and the handle-backed types of the frontend crate ([`Function`], [`Module`],
//! and [`NDArray`]).
//! `TVMRetValue` is the owned version of `TVMPODValue`.
use std::convert::TryFrom;
use tvm_common::{
errors::ValueDowncastError,
ffi::{TVMArrayHandle, TVMFunctionHandle, TVMModuleHandle},
try_downcast,
};
use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};
/// Emits the full set of conversions between a handle-backed frontend type
/// and the FFI value enums.
///
/// Metavariables:
/// - `$frontend`: the frontend wrapper type (e.g. `Function`).
/// - `$tag`: the `TVMArgValue` / `TVMRetValue` variant that carries its handle.
/// - `$raw`: the raw FFI handle type stored in that variant.
/// - `$make`: a constructor path taking a `$raw` and producing a `$frontend`.
///
/// Generated impls:
/// - `From<&$frontend>` and `From<&mut $frontend>` for `TVMArgValue`;
/// - `TryFrom<TVMArgValue>` and `TryFrom<&TVMArgValue>` for `$frontend`;
/// - `From<$frontend>` for `TVMRetValue`;
/// - `TryFrom<TVMRetValue>` for `$frontend`.
macro_rules! impl_handle_val {
    ($frontend:ty, $tag:ident, $raw:ty, $make:path) => {
        // Frontend value -> argument: only the raw handle is copied out.
        impl<'a> From<&'a $frontend> for TVMArgValue<'a> {
            fn from(arg: &'a $frontend) -> Self {
                Self::$tag(arg.handle() as $raw)
            }
        }

        impl<'a> From<&'a mut $frontend> for TVMArgValue<'a> {
            fn from(arg: &'a mut $frontend) -> Self {
                Self::$tag(arg.handle() as $raw)
            }
        }

        // Argument -> frontend value. `$frontend` (not `Self`) is handed to
        // `try_downcast!` so the expected-type name it reports stays the same.
        impl<'a> TryFrom<TVMArgValue<'a>> for $frontend {
            type Error = ValueDowncastError;

            fn try_from(value: TVMArgValue<'a>) -> Result<Self, Self::Error> {
                try_downcast!(value -> $frontend, |TVMArgValue::$tag(handle)| { $make(handle) })
            }
        }

        impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $frontend {
            type Error = ValueDowncastError;

            fn try_from(value: &'a TVMArgValue<'v>) -> Result<Self, Self::Error> {
                // The pattern binds by reference here, hence the deref.
                try_downcast!(value -> $frontend, |TVMArgValue::$tag(handle)| { $make(*handle) })
            }
        }

        // Frontend value <-> return value.
        impl From<$frontend> for TVMRetValue {
            fn from(value: $frontend) -> TVMRetValue {
                Self::$tag(value.handle() as $raw)
            }
        }

        impl TryFrom<TVMRetValue> for $frontend {
            type Error = ValueDowncastError;

            fn try_from(value: TVMRetValue) -> Result<Self, Self::Error> {
                try_downcast!(value -> $frontend, |TVMRetValue::$tag(handle)| { $make(handle) })
            }
        }
    };
}
// Bind each handle-backed frontend type to the value variant that carries its
// raw FFI handle, and to the constructor that rebuilds it from that handle.
impl_handle_val!(NDArray, ArrayHandle, TVMArrayHandle, NDArray::new);
impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new);
impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new);
#[cfg(test)]
mod tests {
    use std::{convert::TryInto, str::FromStr};

    use tvm_common::{TVMByteArray, TVMContext, TVMType};

    use super::*;

    /// A byte array converted into a `TVMRetValue` downcasts back to the same bytes.
    #[test]
    fn bytearray() {
        let w = vec![1u8, 2, 3, 4, 5];
        let v = TVMByteArray::from(w.as_slice());
        let tvm: TVMByteArray = TVMRetValue::from(v).try_into().unwrap();
        // Compare against the source slice directly; copying it into a fresh
        // `Vec` first (as `.iter().map(|e| *e).collect()` did) is redundant.
        assert_eq!(tvm.data(), w.as_slice());
    }

    /// A `TVMType` round-trips through `TVMRetValue` unchanged.
    #[test]
    fn ty() {
        let t = TVMType::from_str("int32").unwrap();
        let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap();
        assert_eq!(tvm, t);
    }

    /// A `TVMContext` round-trips through `TVMRetValue` unchanged.
    #[test]
    fn ctx() {
        let c = TVMContext::from_str("gpu").unwrap();
        let tvm: TVMContext = TVMRetValue::from(c).try_into().unwrap();
        assert_eq!(tvm, c);
    }
}