blob: d8e3f2c9528235fcd6b9c1e7940600f2252cac55 [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file shape.h
* \brief definition of shape
* \author Chuntao Hong, Zhang Chen
*/
#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_SHAPE_H_
#define CPP_PACKAGE_INCLUDE_MXNET_CPP_SHAPE_H_
#include <algorithm>
#include <cctype>
#include <istream>
#include <ostream>
#include <vector>
#include "mxnet-cpp/base.h"
namespace mxnet {
namespace cpp {
/*!
* \brief dynamic shape class that can hold shape
* of arbirary dimension
*/
struct Shape {
public:
/*! \brief constructor */
Shape()
: ndim_(0),
num_heap_allocated_(0),
data_heap_(NULL) {}
/*!
* \brief constructor from a vector of index_t
* \param v the vector
*/
explicit Shape(const std::vector<index_t> &v)
: ndim_(v.size()) {
if (ndim_ <= kStackCache) {
data_heap_ = NULL;
num_heap_allocated_ = 0;
std::copy(v.begin(), v.end(), data_stack_);
} else {
data_heap_ = new index_t[ndim_];
num_heap_allocated_ = ndim_;
std::copy(v.begin(), v.end(), data_heap_);
}
}
/*!
* \brief constructor one dimmension shape
* \param s1 size of the first dimmension
*/
explicit Shape(index_t s1)
: ndim_(1) {
if (ndim_ <= kStackCache) {
data_heap_ = NULL;
num_heap_allocated_ = 0;
data_stack_[0] = s1;
} else {
data_heap_ = new index_t[ndim_];
num_heap_allocated_ = ndim_;
data_heap_[0] = s1;
}
}
/*!
* \brief constructor two dimmension shape
* \param s1 size of the first dimmension
* \param s2 size of the second dimmension
*/
Shape(index_t s1, index_t s2)
: ndim_(2) {
if (ndim_ <= kStackCache) {
data_heap_ = NULL;
num_heap_allocated_ = 0;
data_stack_[0] = s1;
data_stack_[1] = s2;
} else {
data_heap_ = new index_t[ndim_];
num_heap_allocated_ = ndim_;
data_heap_[0] = s1;
data_heap_[1] = s2;
}
}
/*!
* \brief constructor three dimmension shape
* \param s1 size of the first dimmension
* \param s2 size of the second dimmension
* \param s3 size of the third dimmension
*/
Shape(index_t s1, index_t s2, index_t s3)
: ndim_(3) {
if (ndim_ <= kStackCache) {
data_heap_ = NULL;
num_heap_allocated_ = 0;
data_stack_[0] = s1;
data_stack_[1] = s2;
data_stack_[2] = s3;
} else {
data_heap_ = new index_t[ndim_];
num_heap_allocated_ = ndim_;
data_heap_[0] = s1;
data_heap_[1] = s2;
data_heap_[2] = s3;
}
}
/*!
* \brief constructor four dimmension shape
* \param s1 size of the first dimmension
* \param s2 size of the second dimmension
* \param s3 size of the third dimmension
* \param s4 size of the fourth dimmension
*/
Shape(index_t s1, index_t s2, index_t s3, index_t s4)
: ndim_(4) {
if (ndim_ <= kStackCache) {
data_heap_ = NULL;
num_heap_allocated_ = 0;
data_stack_[0] = s1;
data_stack_[1] = s2;
data_stack_[2] = s3;
data_stack_[3] = s4;
} else {
data_heap_ = new index_t[ndim_];
num_heap_allocated_ = ndim_;
data_heap_[0] = s1;
data_heap_[1] = s2;
data_heap_[2] = s3;
data_heap_[3] = s4;
}
}
/*!
* \brief constructor five dimmension shape
* \param s1 size of the first dimmension
* \param s2 size of the second dimmension
* \param s3 size of the third dimmension
* \param s4 size of the fourth dimmension
* \param s5 size of the fifth dimmension
*/
Shape(index_t s1, index_t s2, index_t s3, index_t s4, index_t s5)
: ndim_(5) {
if (ndim_ <= kStackCache) {
data_heap_ = NULL;
num_heap_allocated_ = 0;
data_stack_[0] = s1;
data_stack_[1] = s2;
data_stack_[2] = s3;
data_stack_[3] = s4;
data_stack_[4] = s5;
} else {
data_heap_ = new index_t[ndim_];
num_heap_allocated_ = ndim_;
data_heap_[0] = s1;
data_heap_[1] = s2;
data_heap_[2] = s3;
data_heap_[3] = s4;
data_heap_[5] = s5;
}
}
/*!
* \brief constructor from Shape
* \param s the source shape
*/
Shape(const Shape &s)
: ndim_(s.ndim_) {
if (ndim_ <= kStackCache) {
data_heap_ = NULL;
num_heap_allocated_ = 0;
std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
} else {
data_heap_ = new index_t[ndim_];
num_heap_allocated_ = ndim_;
std::copy(s.data_heap_, s.data_heap_ + ndim_, data_heap_);
}
}
#if MSHADOW_IN_CXX11
/*!
* \brief move constructor from Shape
* \param s the source shape
*/
Shape(Shape &&s)
: ndim_(s.ndim_),
num_heap_allocated_(s.num_heap_allocated_),
data_heap_(s.data_heap_) {
if (ndim_ <= kStackCache) {
std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
}
// remove data heap space from s
s.data_heap_ = NULL;
}
#endif
/*! \brief destructor */
~Shape() {
// data_heap_ can be NULL
delete[] data_heap_;
}
/*!
* \brief copy shape from content betwen two iterators
* \param begin the beginning of iterator
* \param end the end of the iterator
* \tparam RandomAccessIterator iterator type
*/
template<typename RandomAccessIterator>
inline void CopyFrom(RandomAccessIterator begin,
RandomAccessIterator end) {
this->SetDim(end - begin);
std::copy(begin, end, data());
}
/*!
* \brief assignment from shape
* \param shape source shape
* \return reference of self
*/
inline Shape &operator=(const Shape &shape) {
this->SetDim(shape.ndim_);
const index_t *src = shape.data();
std::copy(src, src + ndim_, data());
return *this;
}
/*!
* \brief assignment from vector
* \param shape source shape
* \return reference of self
*/
inline Shape &operator=(const std::vector<index_t> &shape) {
this->CopyFrom(shape.begin(), shape.end());
return *this;
}
/*! \return the data content of the shape */
inline const index_t *data() const {
return ndim_ <= kStackCache ? data_stack_ : data_heap_;
}
/*! \return the data content of the shape */
inline index_t *data() {
return ndim_ <= kStackCache ? data_stack_ : data_heap_;
}
/*! \brief return number of dimension of the tensor inside */
inline index_t ndim(void) const {
return ndim_;
}
/*!
* \brief get corresponding index
* \param i dimension index
* \return the corresponding dimension size
*/
inline index_t &operator[](index_t i) {
return data()[i];
}
/*!
* \brief get corresponding index
* \param i dimension index
* \return the corresponding dimension size
*/
inline const index_t &operator[](index_t i) const {
return data()[i];
}
/*! \brief total number of elements in the tensor */
inline size_t Size(void) const {
size_t size = 1;
const index_t *d = this->data();
for (index_t i = 0; i < ndim_; ++i) {
size *= d[i];
}
return size;
}
/*!
* \return whether two shape equals
* \param s the shape to compare against
*/
inline bool operator==(const Shape &s) const {
if (ndim_ != s.ndim_) return false;
if (ndim_ <= kStackCache) {
for (index_t i = 0; i < ndim_; ++i) {
if (data_stack_[i] != s.data_stack_[i]) return false;
}
} else {
for (index_t i = 0; i < ndim_; ++i) {
if (data_heap_[i] != s.data_heap_[i]) return false;
}
}
return true;
}
/*!
* \return whether two shape not equals
* \param s the shape to compare against
*/
inline bool operator!=(const Shape &s) const {
return !(*this == s);
}
friend std::ostream &operator<<(std::ostream &os, const Shape &shape);
friend std::istream &operator>>(std::istream &is, Shape &shape);
private:
// the shape will be stored in data_stack_
// when dimension is smaller than kStackCache
// when it is bigger, it will be stored in data_heap_;
/*! \brief size of in stack space */
static const index_t kStackCache = 5;
/*! \brief number of dimnsion of the shape */
index_t ndim_;
/*! \brief number of cells allocated in data_heap_ */
index_t num_heap_allocated_;
/*! \brief in stack space used to store shape when it is small */
index_t data_stack_[kStackCache];
/*! \brief space to store shape when dimension is big*/
index_t *data_heap_;
/*!
* \brief internal function to set the dimension
* \param dim the dimension of the shape
*/
inline void SetDim(index_t dim) {
if (dim > kStackCache &&
dim > num_heap_allocated_) {
// data_heap_ can be NULL
delete[] data_heap_;
data_heap_ = new index_t[dim];
num_heap_allocated_ = dim;
}
ndim_ = dim;
}
};
/*!
* \brief allow string printing of the shape
* \param os the output stream
* \param shape the shape
* \return the ostream
*/
inline std::ostream &operator<<(std::ostream &os, const Shape &shape) {
os << '(';
for (index_t i = 0; i < shape.ndim(); ++i) {
if (i != 0) os << ',';
os << static_cast<int>(shape[i]); // Supports negative Shape 'special codes' for inferring
}
// python style tuple
if (shape.ndim() == 1) os << ',';
os << ')';
return os;
}
/*!
* \brief read shape from the istream
* \param is the input stream
* \param shape the shape
* \return the istream
*/
inline std::istream &operator>>(std::istream &is, Shape &shape) {
// get (
while (true) {
char ch = is.get();
if (ch == '(') break;
if (!isspace(ch)) {
is.setstate(std::ios::failbit);
return is;
}
}
index_t idx;
std::vector<index_t> tmp;
while (is >> idx) {
tmp.push_back(idx);
char ch;
do {
ch = is.get();
} while (isspace(ch));
if (ch == ',') {
while (true) {
ch = is.peek();
if (isspace(ch)) {
is.get(); continue;
}
if (ch == ')') {
is.get(); break;
}
break;
}
if (ch == ')') break;
} else if (ch == ')') {
break;
} else {
is.setstate(std::ios::failbit);
return is;
}
}
shape.CopyFrom(tmp.begin(), tmp.end());
return is;
}
} // namespace cpp
} // namespace mxnet
#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_SHAPE_H_