blob: 27b9853a74b74bc6a9aba8d07fd1a678a2a52e6a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file container.h
* \brief Array container
*/
// Acknowledgement: This file originates from incubator-tvm
#ifndef MXNET_NODE_CONTAINER_H_
#define MXNET_NODE_CONTAINER_H_
#include <mxnet/node/node.h>
#include <type_traits>
#include <vector>
#include <initializer_list>
#include <unordered_map>
#include <utility>
#include <string>
namespace mxnet {
/*!
 * \brief Backing storage node for Array<T>.
 *
 * Holds the untyped element handles of an array; the typed view over these
 * handles is provided by Array<T>, which names this class as its ContainerType.
 */
class ArrayNode : public Object {
 public:
  /*! \brief type key string associated with this node type */
  static constexpr const char* _type_key = "Array";
  /*! \brief the element handles, in array order */
  std::vector<ObjectRef> data;

  MXNET_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object);
};
/*!
 * \brief Iterator wrapper that yields Converter::convert(*it) for each
 *  position of an underlying iterator.
 *
 * Dereferencing produces the converted value by value (not a reference).
 * Subtraction of two adapters is only available when the wrapped iterator
 * is a random access iterator.
 *
 * \tparam Converter type exposing `ResultType` and a static `convert` function.
 * \tparam TIter the wrapped (content) iterator type.
 */
template <typename Converter, typename TIter>
class IterAdapter {
 public:
  using iterator_category = typename std::iterator_traits<TIter>::iterator_category;
  using difference_type = typename std::iterator_traits<TIter>::difference_type;
  using value_type = typename Converter::ResultType;
  using pointer = value_type*;
  using reference = value_type&;  // NOLINT(*)

  /*! \brief wrap the given underlying iterator */
  explicit IterAdapter(TIter iter) : base_(iter) {}

  /*! \brief advance to the next position (prefix increment) */
  inline IterAdapter& operator++() {
    ++base_;
    return *this;
  }

  /*! \brief adapter positioned `offset` steps away from this one */
  inline IterAdapter operator+(difference_type offset) const {
    return IterAdapter{base_ + offset};
  }

  /*!
   * \brief distance between two adapters; enabled only for random access
   *  underlying iterators.
   */
  template <typename Self = IterAdapter>
  inline typename std::enable_if<
      std::is_same<typename Self::iterator_category,
                   std::random_access_iterator_tag>::value,
      typename Self::difference_type>::type
  operator-(const IterAdapter& rhs) const {
    return base_ - rhs.base_;
  }

  /*! \brief true when both adapters wrap equal underlying iterators */
  inline bool operator==(IterAdapter other) const {
    return base_ == other.base_;
  }

  /*! \brief negation of operator== */
  inline bool operator!=(IterAdapter other) const {
    return !(base_ == other.base_);
  }

  /*! \brief converted value at the current position */
  inline const value_type operator*() const {
    return Converter::convert(*base_);
  }

 private:
  /*! \brief the wrapped underlying iterator */
  TIter base_;
};
/*!
* \brief Array container of NodeRef in DSL graph.
* Array implements copy on write semantics, which means array is mutable
* but copy will happen when array is referenced in more than two places.
*
* operator[] only provide const acces, use Set to mutate the content.
* \tparam T The content NodeRef type.
*/
template<typename T,
typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type >
class Array : public ObjectRef {
public:
/*!
* \brief default constructor
*/
Array() {
data_ = make_object<ArrayNode>();
}
/*!
* \brief move constructor
* \param other source
*/
Array(Array<T> && other) { // NOLINT(*)
data_ = std::move(other.data_);
}
/*!
* \brief copy constructor
* \param other source
*/
Array(const Array<T> &other) { // NOLINT(*)
data_ = std::move(other.data_);
}
/*!
* \brief constructor from pointer
* \param n the container pointer
*/
explicit Array(runtime::ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief constructor from iterator
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
Array(IterType begin, IterType end) {
assign(begin, end);
}
/*!
* \brief constructor from initializer list
* \param init The initalizer list
*/
Array(std::initializer_list<T> init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
Array(const std::vector<T>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief Constructs a container with n elements. Each element is a copy of val
* \param n The size of the container
* \param val The init value
*/
explicit Array(size_t n, const T& val) {
auto tmp_node = make_object<ArrayNode>();
for (size_t i = 0; i < n; ++i) {
tmp_node->data.push_back(val);
}
data_ = std::move(tmp_node);
}
/*!
* \brief move assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(Array<T> && other) {
data_ = std::move(other.data_);
return *this;
}
/*!
* \brief copy assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(const Array<T> & other) {
data_ = other.data_;
return *this;
}
/*!
* \brief reset the array to content from iterator.
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
void assign(IterType begin, IterType end) {
auto n = make_object<ArrayNode>();
for (IterType it = begin; it != end; ++it) {
n->data.push_back(T(*it));
}
data_ = std::move(n);
}
/*!
* \brief Read i-th element from array.
* \param i The index
* \return the i-th element.
*/
inline const T operator[](size_t i) const {
return DowncastNoCheck<T>(
static_cast<const ArrayNode*>(data_.get())->data[i]);
}
/*! \return The size of the array */
inline size_t size() const {
if (data_.get() == nullptr) return 0;
return static_cast<const ArrayNode*>(data_.get())->data.size();
}
/*!
* \brief copy on write semantics
* Do nothing if current handle is the unique copy of the array.
* Otherwise make a new copy of the array to ensure the current handle
* hold a unique copy.
*
* \return Handle to the internal node container(which ganrantees to be unique)
*/
inline ArrayNode* CopyOnWrite() {
if (data_.get() == nullptr || !data_.unique()) {
runtime::ObjectPtr<ArrayNode> n = make_object<ArrayNode>();
n->data = static_cast<ArrayNode*>(data_.get())->data;
runtime::ObjectPtr<Object>(std::move(n)).swap(data_);
}
return static_cast<ArrayNode*>(data_.get());
}
/*!
* \brief push a new item to the back of the list
* \param item The item to be pushed.
*/
inline void push_back(const T& item) {
ArrayNode* n = this->CopyOnWrite();
n->data.push_back(item);
}
/*!
* \brief Resize the array.
* \param size The new size.
*/
inline void resize(size_t size) {
ArrayNode* n = this->CopyOnWrite();
n->data.resize(size);
}
/*!
* \brief set i-th element of the array.
* \param i The index
* \param value The value to be setted.
*/
inline void Set(size_t i, const T& value) {
ArrayNode* n = this->CopyOnWrite();
n->data[i] = value;
}
/*! \return whether array is empty */
inline bool empty() const {
return size() == 0;
}
/*!
* \brief Helper function to apply fmutate to mutate an array.
* \param fmutate The transformation function T -> T.
* \tparam F the type of the mutation function.
* \note This function performs copy on write optimization.
*/
template<typename F>
inline void MutateByApply(F fmutate) {
ArrayNode* ptr = static_cast<ArrayNode*>(data_.get());
if (ptr == nullptr) return;
if (data_.unique()) {
// Copy on write optimization.
// Perform inplace update because this is an unique copy.
for (size_t i = 0; i < ptr->data.size(); ++i) {
// It is important to use move here
// to make prevent the element's ref count from increasing
// so fmutate itself can perform copy-on-write optimization
T old_elem = DowncastNoCheck<T>(std::move(ptr->data[i]));
T new_elem = fmutate(std::move(old_elem));
ptr->data[i] = std::move(new_elem);
}
} else {
// lazily trigger copy if there is element change.
runtime::ObjectPtr<ArrayNode> copy;
for (size_t i = 0; i < ptr->data.size(); ++i) {
T old_elem = DowncastNoCheck<T>(ptr->data[i]);
T new_elem = fmutate(old_elem);
if (!new_elem.same_as(ptr->data[i])) {
// copy the old array
if (copy == nullptr) {
copy = runtime::make_object<ArrayNode>(*ptr);
}
copy->data[i] = std::move(new_elem);
}
}
// replace the data with the new copy.
if (copy != nullptr) {
data_ = std::move(copy);
}
}
}
/*! \brief specify container node */
using ContainerType = ArrayNode;
struct ValueConverter {
using ResultType = T;
static inline T convert(const ObjectRef& n) {
return DowncastNoCheck<T>(n);
}
};
using iterator = IterAdapter<ValueConverter,
std::vector<ObjectRef>::const_iterator>;
using reverse_iterator = IterAdapter<
ValueConverter,
std::vector<ObjectRef>::const_reverse_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const ArrayNode*>(data_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const ArrayNode*>(data_.get())->data.end());
}
/*! \return rbegin iterator */
inline reverse_iterator rbegin() const {
return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rbegin());
}
/*! \return rend iterator */
inline reverse_iterator rend() const {
return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rend());
}
};
} // namespace mxnet
#endif // MXNET_NODE_CONTAINER_H_