blob: 3de7a8590fde081159d661c8fc44e93e9b4dd89d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
use failure::Error;
use ndarray;
use tvm_common::{
array::{DataType, TVMContext},
ffi::{
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt,
DLDataTypeCode_kDLUInt, DLTensor,
},
};
use crate::allocator::Allocation;
/// A `Storage` is a container which holds `Tensor` data.
#[derive(PartialEq)]
pub enum Storage<'a> {
    /// A `Storage` which owns its contained bytes.
    Owned(Allocation),
    /// A view of an existing `Storage`: a borrowed byte buffer plus the
    /// alignment (in bytes) of that buffer.
    View(&'a mut [u8], usize), // buffer, align
}
impl<'a> Storage<'a> {
    /// Allocates a new owned `Storage` of `size` bytes with the given
    /// alignment (allocator default when `align` is `None`).
    pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
        Ok(Storage::Owned(Allocation::new(size, align)?))
    }

    /// Returns a mutable pointer to the start of the underlying bytes.
    ///
    /// NOTE(review): takes `&self` yet hands out a `*mut u8`; callers are
    /// responsible for upholding aliasing rules when writing through it.
    pub fn as_mut_ptr(&self) -> *mut u8 {
        match self {
            Storage::Owned(alloc) => alloc.as_mut_ptr(),
            Storage::View(slice, _) => slice.as_ptr() as *mut u8,
        }
    }

    /// Returns the size of the storage in bytes.
    pub fn size(&self) -> usize {
        match self {
            Storage::Owned(alloc) => alloc.size(),
            Storage::View(slice, _) => slice.len(),
        }
    }

    /// Returns the byte alignment of the storage.
    pub fn align(&self) -> usize {
        match self {
            Storage::Owned(alloc) => alloc.align(),
            Storage::View(_, align) => *align,
        }
    }

    /// Returns a const pointer to the start of the underlying bytes.
    pub fn as_ptr(&self) -> *const u8 {
        self.as_mut_ptr() as *const _
    }

    /// Returns a `Storage::View` which points to an owned `Storage::Owned`.
    ///
    /// NOTE(review): both arms fabricate a fresh `&mut [u8]` over bytes still
    /// reachable through `self`; callers must not mutate through both handles
    /// at once.
    pub fn view(&self) -> Storage<'a> {
        match self {
            Storage::Owned(alloc) => Storage::View(
                unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
                self.align(),
            ),
            Storage::View(slice, _) => Storage::View(
                unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
                self.align(),
            ),
        }
    }

    /// Whether this storage owns its bytes (is a `Storage::Owned`).
    pub fn is_owned(&self) -> bool {
        match self {
            Storage::Owned(_) => true,
            _ => false,
        }
    }

    /// Returns an owned version of this storage via cloning (full byte copy).
    pub fn to_owned(&self) -> Storage<'static> {
        let s = Storage::new(self.size(), Some(self.align())).unwrap();
        unsafe {
            // The regions cannot overlap: `s` was freshly allocated above.
            s.as_mut_ptr()
                .copy_from_nonoverlapping(self.as_ptr(), self.size());
        }
        s
    }
}
impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
    /// Wraps a typed slice as a borrowed, untyped `Storage::View` over the
    /// same memory, without copying.
    ///
    /// NOTE(review): this casts away constness and fabricates a `&mut [u8]`
    /// over borrowed data; writing through the resulting view would be UB.
    /// Callers must treat the view as read-only.
    fn from(data: &'d [T]) -> Self {
        let bytes = unsafe {
            slice::from_raw_parts_mut(
                data.as_ptr() as *const u8 as *mut u8,
                // `size_of` already returns `usize`; no cast needed.
                data.len() * mem::size_of::<T>(),
            )
        };
        Storage::View(bytes, mem::align_of::<T>())
    }
}
/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
/// converted to `ndarray::Array` for non-TVM processing.
///
/// # Examples
///
/// ```
/// extern crate ndarray;
///
/// let mut a_nd: ndarray::Array = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
/// let mut a: Tensor = a_nd.into();
/// let mut a_dl: DLTensor = (&mut a).into();
/// call_packed!(tvm_fn, &mut a_dl);
///
/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
/// let mut a_nd = ndarray::Array::try_from(&a).unwrap();
/// ```
#[derive(PartialEq)]
pub struct Tensor<'a> {
    /// The bytes which contain the data this `Tensor` represents.
    pub(crate) data: Storage<'a>,
    /// The TVM device context this tensor is associated with.
    pub(crate) ctx: TVMContext,
    /// The element type of the tensor.
    pub(crate) dtype: DataType,
    /// The extent of each dimension.
    pub(crate) shape: Vec<i64>,
    // ^ not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
    /// The `Tensor` strides, in number of elements per dimension.
    /// Can be `None` if the `Tensor` is contiguous.
    pub(crate) strides: Option<Vec<usize>>,
    /// Offset in bytes from the start of `data` to the first element.
    pub(crate) byte_offset: isize,
    /// The number of elements in the `Tensor`.
    pub(crate) size: usize,
}

// SAFETY NOTE(review): a `Tensor` may hold a `Storage::View` into memory
// owned elsewhere; sending it to another thread assumes that buffer is not
// concurrently mutated — confirm this invariant at call sites.
unsafe impl<'a> Send for Tensor<'a> {}
impl<'a> Tensor<'a> {
    /// Returns a copy of this `Tensor`'s shape.
    pub fn shape(&self) -> Vec<i64> {
        self.shape.clone()
    }

    /// Returns the data of this `Tensor` as a `Vec`.
    ///
    /// # Panics
    ///
    /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
    pub fn to_vec<T: 'static + std::fmt::Debug + Clone>(&self) -> Vec<T> {
        assert!(self.is_contiguous());
        assert!(self.dtype.is_type::<T>());
        // Reinterprets the raw bytes as `size` elements of `T` and clones
        // them out. NOTE(review): does not apply `byte_offset` (unlike
        // `copy`/`as_dltensor`) — confirm callers never use a nonzero offset
        // here.
        unsafe { slice::from_raw_parts(self.data.as_ptr() as *const T, self.size).to_vec() }
    }

    /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory.
    pub fn is_contiguous(&self) -> bool {
        match self.strides {
            // Absent strides means row-major contiguous by convention.
            None => true,
            Some(ref strides) => {
                // check that stride for each dimension is the
                // product of all trailing dimensons' shapes
                self.shape
                    .iter()
                    .zip(strides)
                    .rfold(
                        // Accumulator: (still contiguous?, expected stride
                        // for the current dimension, in elements).
                        (true, 1),
                        |(is_contig, expected_stride), (shape, stride)| {
                            (
                                is_contig && *stride == expected_stride,
                                expected_stride * (*shape as usize),
                            )
                        },
                    )
                    .0
            }
        }
    }

    /// Copies the contents of `other` into `self`.
    ///
    /// # Panics
    ///
    /// Panics if the tensors differ in dtype or element count, or if either
    /// tensor is not contiguous.
    pub fn copy(&mut self, other: &Tensor) {
        assert!(
            self.dtype == other.dtype && self.size == other.size,
            "Tensor shape/dtype mismatch."
        );
        assert!(
            self.is_contiguous() && other.is_contiguous(),
            "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`",
            self.strides,
            other.strides
        );
        unsafe {
            // Byte-for-byte copy of `other`'s payload, honoring both
            // tensors' byte offsets.
            self.data
                .as_mut_ptr()
                .offset(self.byte_offset as isize)
                .copy_from_nonoverlapping(
                    other.data.as_mut_ptr().offset(other.byte_offset),
                    other.size * other.dtype.itemsize(),
                );
        }
    }

    /// Returns an owned version of this `Tensor` via cloning.
    ///
    /// NOTE(review): resets `strides` to `None` while copying bytes verbatim;
    /// confirm callers only use this on contiguous tensors.
    pub fn to_owned(&self) -> Tensor<'static> {
        let t = Tensor {
            data: self.data.to_owned(),
            ctx: self.ctx.clone(),
            dtype: self.dtype.clone(),
            size: self.size.clone(),
            shape: self.shape.clone(),
            strides: None,
            byte_offset: 0,
        };
        // SAFETY: `t.data` is `Storage::Owned`, so `t` borrows nothing with
        // lifetime `'a`; the transmute only rebrands the lifetime parameter.
        unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
    }

    /// Builds a `Tensor` over `storage` with metadata (dtype, shape, strides)
    /// taken from the given `ndarray::Array`.
    fn from_array_storage<'s, T, D: ndarray::Dimension>(
        arr: &ndarray::Array<T, D>,
        storage: Storage<'s>,
        type_code: usize,
    ) -> Tensor<'s> {
        let type_width = mem::size_of::<T>() as usize;
        Tensor {
            data: storage,
            ctx: TVMContext::default(),
            dtype: DataType {
                code: type_code,
                // dtype bits are per element; lanes fixed at 1 (no vector types).
                bits: 8 * type_width,
                lanes: 1,
            },
            size: arr.len(),
            shape: arr.shape().iter().map(|&v| v as i64).collect(),
            strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
            byte_offset: 0,
        }
    }

    /// Builds a `DLTensor` header pointing at this tensor's data (no copy).
    /// When `flatten` is true the result is presented as 1-D of `size`
    /// elements, which requires contiguity.
    ///
    /// NOTE(review): the returned struct borrows `self`'s buffers (`data`,
    /// `shape`, `strides`, and — when flattened — `&self.size` reinterpreted
    /// as an `i64`), so it must not outlive `self`. The `usize`-as-`i64`
    /// shape pointer assumes matching widths (64-bit) — confirm.
    pub(crate) fn as_dltensor(&self, flatten: bool) -> DLTensor {
        assert!(!flatten || self.is_contiguous());
        DLTensor {
            // `byte_offset` is folded into `data` here, so the header's own
            // byte_offset below is 0.
            data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void,
            ctx: DLContext::from(&self.ctx),
            ndim: if flatten { 1 } else { self.shape.len() } as i32,
            dtype: DLDataType::from(&self.dtype),
            shape: if flatten {
                &self.size as *const _ as *mut i64
            } else {
                self.shape.as_ptr()
            } as *mut i64,
            // Null strides signals contiguous layout to consumers.
            strides: if flatten || self.is_contiguous() {
                ptr::null_mut()
            } else {
                self.strides.as_ref().unwrap().as_ptr()
            } as *mut i64,
            byte_offset: 0,
        }
    }
}
/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
///
/// Expands to a `TryFrom<&Tensor>` impl for `ndarray::ArrayD<$type>` that
/// fails (rather than panics) when the tensor's dtype is not `$dtype`.
macro_rules! impl_ndarray_try_from_tensor {
    ($type:ty, $dtype:expr) => {
        impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
            type Error = Error;
            fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
                // Check the dtype up front so `to_vec` cannot panic on a
                // type mismatch below.
                ensure!(
                    tensor.dtype == $dtype,
                    "Cannot convert Tensor with dtype {:?} to ndarray",
                    tensor.dtype
                );
                // Copies the data out; shape errors surface via `?`.
                Ok(ndarray::Array::from_shape_vec(
                    tensor
                        .shape
                        .iter()
                        .map(|s| *s as usize)
                        .collect::<Vec<usize>>(),
                    tensor.to_vec::<$type>(),
                )?)
            }
        }
    };
}
/// Declares a `pub const DataType` named `$name` with the given DLPack type
/// code, bit width, and lane count.
macro_rules! make_dtype_const {
    ($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
        pub const $name: DataType = DataType {
            code: $code as usize,
            bits: $bits,
            lanes: $lanes,
        };
    };
}
// Common scalar dtypes used by the runtime.
make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
// `Tensor` -> `ndarray::ArrayD` conversions for each supported dtype.
impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
fn from(tensor: &'a Tensor<'t>) -> Self {
Tensor::as_dltensor(tensor, false /* flatten */)
}
}
impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
fn from(tensor: &'a mut Tensor<'t>) -> Self {
Tensor::as_dltensor(tensor, false /* flatten */)
}
}
impl<'a> From<DLTensor> for Tensor<'a> {
    /// Wraps a `DLTensor`'s memory as a borrowed (non-owning) `Tensor`.
    ///
    /// NOTE(review): the resulting `Tensor` views the `DLTensor`'s buffers;
    /// the caller must keep that memory alive for the tensor's lifetime.
    fn from(dlt: DLTensor) -> Self {
        unsafe {
            let dtype = DataType::from(dlt.dtype);
            // The shape array carries exactly `ndim` extents.
            let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec();
            let size = shape.iter().map(|v| *v as usize).product::<usize>();
            // Borrow the payload bytes without copying.
            let storage = Storage::from(slice::from_raw_parts(
                dlt.data as *const u8,
                dtype.itemsize() * size,
            ));
            Self {
                data: storage,
                ctx: TVMContext::default(),
                dtype,
                size,
                shape,
                // BUGFIX: like `shape`, the strides array (when non-null) has
                // `ndim` entries — the previous code read `size` entries,
                // over-reading past the array for any tensor with >1 element.
                // The i64->usize pointer cast (assumes non-negative strides,
                // matching widths) is retained from the original.
                strides: if dlt.strides.is_null() {
                    None
                } else {
                    Some(
                        slice::from_raw_parts(dlt.strides as *const usize, dlt.ndim as usize)
                            .to_vec(),
                    )
                },
                byte_offset: dlt.byte_offset as isize,
            }
        }
    }
}
/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
///
/// # Panics
///
/// Panics if the ndarray is not contiguous.
macro_rules! impl_tensor_from_ndarray {
    ($type:ty, $typecode:expr) => {
        // Owned array: the data is copied into an owned `Storage` so the
        // resulting `Tensor` is `'static`.
        impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
            fn from(arr: ndarray::Array<$type, D>) -> Self {
                let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous"));
                Tensor::from_array_storage(&arr, storage.to_owned(), $typecode as usize)
            }
        }
        // Borrowed array: the `Tensor` views the array's memory without
        // copying and is tied to its lifetime.
        impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
            fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
                let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous"));
                Tensor::from_array_storage(arr, storage, $typecode as usize)
            }
        }
    };
}
// ndarray -> Tensor conversions for each supported element type.
impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);