| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice}; |
| |
| use failure::Error; |
| use ndarray; |
| use tvm_common::{ |
| array::{DataType, TVMContext}, |
| ffi::{ |
| DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, |
| DLDataTypeCode_kDLUInt, DLTensor, |
| }, |
| }; |
| |
| use crate::allocator::Allocation; |
| |
/// A `Storage` is a container which holds `Tensor` data.
///
/// It either owns its bytes (`Owned`) or borrows a byte range from elsewhere (`View`), for
/// example another `Storage` (see `Storage::view`) or an arbitrary slice (see the
/// `From<&[T]>` impl).
#[derive(PartialEq)]
pub enum Storage<'a> {
    /// A `Storage` which owns its contained bytes.
    Owned(Allocation),

    /// A non-owning view of existing bytes.
    View(&'a mut [u8], usize), // (viewed bytes, alignment of the viewed data in bytes)
}
| |
| impl<'a> Storage<'a> { |
| pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> { |
| Ok(Storage::Owned(Allocation::new(size, align)?)) |
| } |
| |
| pub fn as_mut_ptr(&self) -> *mut u8 { |
| match self { |
| Storage::Owned(alloc) => alloc.as_mut_ptr(), |
| Storage::View(slice, _) => slice.as_ptr() as *mut u8, |
| } |
| } |
| |
| pub fn size(&self) -> usize { |
| match self { |
| Storage::Owned(alloc) => alloc.size(), |
| Storage::View(slice, _) => slice.len(), |
| } |
| } |
| |
| pub fn align(&self) -> usize { |
| match self { |
| Storage::Owned(alloc) => alloc.align(), |
| Storage::View(_, align) => *align, |
| } |
| } |
| |
| pub fn as_ptr(&self) -> *const u8 { |
| self.as_mut_ptr() as *const _ |
| } |
| |
| /// Returns a `Storage::View` which points to an owned `Storage::Owned`. |
| pub fn view(&self) -> Storage<'a> { |
| match self { |
| Storage::Owned(alloc) => Storage::View( |
| unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) }, |
| self.align(), |
| ), |
| Storage::View(slice, _) => Storage::View( |
| unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) }, |
| self.align(), |
| ), |
| } |
| } |
| |
| pub fn is_owned(&self) -> bool { |
| match self { |
| Storage::Owned(_) => true, |
| _ => false, |
| } |
| } |
| |
| /// Returns an owned version of this storage via cloning. |
| pub fn to_owned(&self) -> Storage<'static> { |
| let s = Storage::new(self.size(), Some(self.align())).unwrap(); |
| unsafe { |
| s.as_mut_ptr() |
| .copy_from_nonoverlapping(self.as_ptr(), self.size()); |
| } |
| s |
| } |
| } |
| |
| impl<'d, 's, T> From<&'d [T]> for Storage<'s> { |
| fn from(data: &'d [T]) -> Self { |
| let data = unsafe { |
| slice::from_raw_parts_mut( |
| data.as_ptr() as *const u8 as *mut u8, |
| data.len() * mem::size_of::<T>() as usize, |
| ) |
| }; |
| Storage::View(data, mem::align_of::<T>()) |
| } |
| } |
| |
| /// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`. |
| /// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or |
| /// converted to `ndarray::Array` for non-TVM processing. |
| /// |
| /// # Examples |
| /// |
| /// ``` |
| /// extern crate ndarray; |
| /// |
| /// let mut a_nd: ndarray::Array = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]); |
| /// let mut a: Tensor = a_nd.into(); |
| /// let mut a_dl: DLTensor = (&mut t).into(); |
| /// call_packed!(tvm_fn, &mut a_dl); |
| /// |
| /// // Array -> Tensor is mostly useful when post-processing TVM graph outputs. |
| /// let mut a_nd = ndarray::Array::try_from(&a).unwrap(); |
| /// ``` |
| #[derive(PartialEq)] |
| pub struct Tensor<'a> { |
| /// The bytes which contain the data this `Tensor` represents. |
| pub(crate) data: Storage<'a>, |
| pub(crate) ctx: TVMContext, |
| pub(crate) dtype: DataType, |
| pub(crate) shape: Vec<i64>, |
| // ^ not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h |
| /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous. |
| pub(crate) strides: Option<Vec<usize>>, |
| pub(crate) byte_offset: isize, |
| /// The number of elements in the `Tensor`. |
| pub(crate) size: usize, |
| } |
| |
| unsafe impl<'a> Send for Tensor<'a> {} |
| |
| impl<'a> Tensor<'a> { |
| pub fn shape(&self) -> Vec<i64> { |
| self.shape.clone() |
| } |
| |
| /// Returns the data of this `Tensor` as a `Vec`. |
| /// |
| /// # Panics |
| /// |
| /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`. |
| pub fn to_vec<T: 'static + std::fmt::Debug + Clone>(&self) -> Vec<T> { |
| assert!(self.is_contiguous()); |
| assert!(self.dtype.is_type::<T>()); |
| unsafe { slice::from_raw_parts(self.data.as_ptr() as *const T, self.size).to_vec() } |
| } |
| |
| /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory. |
| pub fn is_contiguous(&self) -> bool { |
| match self.strides { |
| None => true, |
| Some(ref strides) => { |
| // check that stride for each dimension is the |
| // product of all trailing dimensons' shapes |
| self.shape |
| .iter() |
| .zip(strides) |
| .rfold( |
| (true, 1), |
| |(is_contig, expected_stride), (shape, stride)| { |
| ( |
| is_contig && *stride == expected_stride, |
| expected_stride * (*shape as usize), |
| ) |
| }, |
| ) |
| .0 |
| } |
| } |
| } |
| |
| /// Returns a clone of this `Tensor`. |
| /// |
| /// # Panics |
| /// |
| /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`. |
| pub fn copy(&mut self, other: &Tensor) { |
| assert!( |
| self.dtype == other.dtype && self.size == other.size, |
| "Tensor shape/dtype mismatch." |
| ); |
| assert!( |
| self.is_contiguous() && other.is_contiguous(), |
| "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`", |
| self.strides, |
| other.strides |
| ); |
| unsafe { |
| self.data |
| .as_mut_ptr() |
| .offset(self.byte_offset as isize) |
| .copy_from_nonoverlapping( |
| other.data.as_mut_ptr().offset(other.byte_offset), |
| other.size * other.dtype.itemsize(), |
| ); |
| } |
| } |
| |
| /// Returns an owned version of this `Tensor` via cloning. |
| pub fn to_owned(&self) -> Tensor<'static> { |
| let t = Tensor { |
| data: self.data.to_owned(), |
| ctx: self.ctx.clone(), |
| dtype: self.dtype.clone(), |
| size: self.size.clone(), |
| shape: self.shape.clone(), |
| strides: None, |
| byte_offset: 0, |
| }; |
| unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) } |
| } |
| |
| fn from_array_storage<'s, T, D: ndarray::Dimension>( |
| arr: &ndarray::Array<T, D>, |
| storage: Storage<'s>, |
| type_code: usize, |
| ) -> Tensor<'s> { |
| let type_width = mem::size_of::<T>() as usize; |
| Tensor { |
| data: storage, |
| ctx: TVMContext::default(), |
| dtype: DataType { |
| code: type_code, |
| bits: 8 * type_width, |
| lanes: 1, |
| }, |
| size: arr.len(), |
| shape: arr.shape().iter().map(|&v| v as i64).collect(), |
| strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()), |
| byte_offset: 0, |
| } |
| } |
| |
| pub(crate) fn as_dltensor(&self, flatten: bool) -> DLTensor { |
| assert!(!flatten || self.is_contiguous()); |
| DLTensor { |
| data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void, |
| ctx: DLContext::from(&self.ctx), |
| ndim: if flatten { 1 } else { self.shape.len() } as i32, |
| dtype: DLDataType::from(&self.dtype), |
| shape: if flatten { |
| &self.size as *const _ as *mut i64 |
| } else { |
| self.shape.as_ptr() |
| } as *mut i64, |
| strides: if flatten || self.is_contiguous() { |
| ptr::null_mut() |
| } else { |
| self.strides.as_ref().unwrap().as_ptr() |
| } as *mut i64, |
| byte_offset: 0, |
| } |
| } |
| } |
| |
/// Conversions to `ndarray::ArrayD` from `Tensor`, if the types match.
///
/// The generated `try_from` returns an error (rather than panicking) when the `Tensor`'s
/// dtype differs from `$dtype` or when the `Tensor` is not contiguous.
macro_rules! impl_ndarray_try_from_tensor {
    ($type:ty, $dtype:expr) => {
        impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
            type Error = Error;
            fn try_from(tensor: &'a Tensor<'t>) -> Result<ndarray::ArrayD<$type>, Error> {
                ensure!(
                    tensor.dtype == $dtype,
                    "Cannot convert Tensor with dtype {:?} to ndarray",
                    tensor.dtype
                );
                // `Tensor::to_vec` asserts contiguity; surface that as an error instead of a
                // panic since this is a fallible conversion.
                ensure!(
                    tensor.is_contiguous(),
                    "Cannot convert non-contiguous Tensor with strides {:?} to ndarray",
                    tensor.strides
                );
                let shape = tensor
                    .shape
                    .iter()
                    .map(|s| *s as usize)
                    .collect::<Vec<usize>>();
                Ok(ndarray::Array::from_shape_vec(
                    shape,
                    tensor.to_vec::<$type>(),
                )?)
            }
        }
    };
}
| |
/// Defines a public `DataType` constant `$name` with type code `$code`, bit width `$bits`
/// and lane count `$lanes`.
macro_rules! make_dtype_const {
    ($name:ident, $code:ident, $bits:expr, $lanes:expr) => {
        pub const $name: DataType = DataType {
            lanes: $lanes,
            bits: $bits,
            code: $code as usize,
        };
    };
}
| |
| make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1); |
| make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1); |
| // make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1); |
| make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1); |
| make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1); |
| impl_ndarray_try_from_tensor!(i32, DTYPE_INT32); |
| impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32); |
| impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32); |
| impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64); |
| |
| impl<'a, 't> From<&'a Tensor<'t>> for DLTensor { |
| fn from(tensor: &'a Tensor<'t>) -> Self { |
| Tensor::as_dltensor(tensor, false /* flatten */) |
| } |
| } |
| |
| impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor { |
| fn from(tensor: &'a mut Tensor<'t>) -> Self { |
| Tensor::as_dltensor(tensor, false /* flatten */) |
| } |
| } |
| |
| impl<'a> From<DLTensor> for Tensor<'a> { |
| fn from(dlt: DLTensor) -> Self { |
| unsafe { |
| let dtype = DataType::from(dlt.dtype); |
| let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec(); |
| let size = shape.iter().map(|v| *v as usize).product::<usize>() as usize; |
| let storage = Storage::from(slice::from_raw_parts( |
| dlt.data as *const u8, |
| dtype.itemsize() * size, |
| )); |
| Self { |
| data: storage, |
| ctx: TVMContext::default(), |
| dtype: dtype, |
| size: size, |
| shape: shape, |
| strides: if dlt.strides == ptr::null_mut() { |
| None |
| } else { |
| Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec()) |
| }, |
| byte_offset: dlt.byte_offset as isize, |
| } |
| } |
| } |
| } |
| |
/// Implements `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
///
/// An owned array's elements are copied into a freshly allocated owned `Storage`. A borrowed
/// array is wrapped in a non-owning `Storage::View` over its buffer, so the `Tensor` borrows
/// the array for `'a`.
///
/// # Panics
///
/// The generated conversions panic if the ndarray is not contiguous.
macro_rules! impl_tensor_from_ndarray {
    ($type:ty, $typecode:expr) => {
        impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
            fn from(arr: ndarray::Array<$type, D>) -> Self {
                let elems = arr.as_slice().expect("NDArray must be contiguous");
                let owned = Storage::from(elems).to_owned();
                Tensor::from_array_storage(&arr, owned, $typecode as usize)
            }
        }
        impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
            fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
                let elems = arr.as_slice().expect("NDArray must be contiguous");
                let view = Storage::from(elems);
                Tensor::from_array_storage(arr, view, $typecode as usize)
            }
        }
    };
}
| |
| impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); |
| impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); |
| impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); |
| impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); |
| impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); |
| impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); |