blob: 886bc5b4566bac6800dd67b0ff9d9dc234737480 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lazy_alloc_array.h
* \brief An array that lazily allocates each element the first time
* its cell is visited.
*/
#ifndef MXNET_COMMON_LAZY_ALLOC_ARRAY_H_
#define MXNET_COMMON_LAZY_ALLOC_ARRAY_H_
#include <dmlc/logging.h>
#include <memory>
#include <mutex>
#include <array>
#include <vector>
#include <atomic>
namespace mxnet {
namespace common {
/*!
 * \brief Array of shared elements that are created lazily on first access.
 *
 * Indices in [0, kInitSize) are stored in a fixed-size std::array; larger
 * indices are stored in a std::vector that Get() grows on demand.
 * Get(), ForEach() and Clear() synchronize through an internal mutex.
 * Elements are handed out as std::shared_ptr, so an element obtained from
 * Get() stays alive for its holder even after Clear() drops the array's own
 * reference.
 *
 * \tparam TElem element type.
 */
template <typename TElem>
class LazyAllocArray {
 public:
  LazyAllocArray();
  // The array owns a mutex and per-slot state; it is neither copyable nor
  // assignable (this only makes the already-implicit deletion explicit).
  LazyAllocArray(const LazyAllocArray&) = delete;
  LazyAllocArray& operator=(const LazyAllocArray&) = delete;
  /*!
   * \brief Get the element at \p index, creating it with \p creator if the
   *        slot is still empty.
   * \param index the array index position; must be non-negative.
   * \param creator a callable whose result is wrapped as
   *        std::shared_ptr<TElem>(creator()) -- typically a newly allocated
   *        TElem*. It is invoked with the internal mutex held.
   * \return the element, or nullptr when the slot is empty while Clear() is
   *         in progress.
   */
  template <typename FCreate>
  inline std::shared_ptr<TElem> Get(int index, FCreate creator);
  /*!
   * \brief Call fvisit for every non-null element, in index order, while
   *        holding the internal mutex.
   * \param fvisit a callable of (size_t index, TElem* element).
   */
  template <typename FVisit>
  inline void ForEach(FVisit fvisit);
  /*!
   * \brief Drop the array's reference to every element and empty the
   *        overflow storage. References are released with the internal
   *        mutex temporarily unlocked.
   */
  inline void Clear();

 private:
  /*!
   * \brief Scope guard that is the inverse of std::unique_lock: it releases a
   *        held lock on construction and re-acquires it on destruction.
   *        A null pointer makes it a no-op.
   */
  template <typename SyncObject>
  class unique_unlock {
   public:
    explicit unique_unlock(std::unique_lock<SyncObject>* lock) : lock_(lock) {
      if (lock_) {
        lock_->unlock();
      }
    }
    ~unique_unlock() {
      if (lock_) {
        lock_->lock();
      }
    }
    // A copy would re-lock the same lock a second time on destruction.
    unique_unlock(const unique_unlock&) = delete;
    unique_unlock& operator=(const unique_unlock&) = delete;

   private:
    std::unique_lock<SyncObject>* lock_;
  };
  /*! \brief the initial size of the array */
  static constexpr std::size_t kInitSize = 16;
  /*! \brief mutex used during creation */
  std::mutex create_mutex_;
  /*! \brief internal storage for the first kInitSize indices */
  std::array<std::shared_ptr<TElem>, kInitSize> head_;
  /*! \brief overflow array of more elements */
  std::vector<std::shared_ptr<TElem> > more_;
  /*! \brief Signal shutdown of array */
  std::atomic<bool> is_clearing_;
};
template <typename TElem>
inline LazyAllocArray<TElem>::LazyAllocArray() : is_clearing_(false) {}
// implementations
template <typename TElem>
template <typename FCreate>
inline std::shared_ptr<TElem> LazyAllocArray<TElem>::Get(int index, FCreate creator) {
  // Returns the element at `index`, creating it via creator() on first use.
  // Returns nullptr if the slot is empty while Clear() is running.
  CHECK_GE(index, 0);
  size_t idx = static_cast<size_t>(index);
  if (idx < kInitSize) {
    // Lock-free fast path for already-created head elements. The slot is
    // assigned by other threads (under create_mutex_), so it must be read
    // with std::atomic_load: a plain concurrent read of a std::shared_ptr
    // object that another thread writes is a data race.
    std::shared_ptr<TElem> ptr = std::atomic_load(&head_[idx]);
    if (ptr) {
      return ptr;
    }
    std::lock_guard<std::mutex> lock(create_mutex_);
    if (is_clearing_.load()) {
      return nullptr;
    }
    // Re-check under the lock: another thread may have created the element
    // between the fast-path read and acquiring the mutex. All writers hold
    // the mutex, so a plain read is sufficient here.
    ptr = head_[idx];
    if (!ptr) {
      ptr = std::shared_ptr<TElem>(creator());
      // Publish atomically because the fast path reads without the mutex.
      std::atomic_store(&head_[idx], ptr);
    }
    return ptr;
  }
  // Overflow slots are always accessed with the mutex held.
  std::lock_guard<std::mutex> lock(create_mutex_);
  if (is_clearing_.load()) {
    return nullptr;
  }
  idx -= kInitSize;
  if (more_.size() <= idx) {
    // resize() keeps the vector's geometric growth; reserve(idx + 1) would
    // reallocate on every new index when indices grow one at a time.
    more_.resize(idx + 1);
  }
  std::shared_ptr<TElem>& slot = more_[idx];
  if (!slot) {
    slot = std::shared_ptr<TElem>(creator());
  }
  return slot;
}
template <typename TElem>
inline void LazyAllocArray<TElem>::Clear() {
  std::unique_lock<std::mutex> lock(create_mutex_);
  // While is_clearing_ is set, Get() will not create new elements.
  is_clearing_.store(true);
  // Each slot is detached while create_mutex_ is held, and the loop
  // conditions are also evaluated with the lock held. Only the release of
  // the detached reference happens with the lock dropped (via unique_unlock),
  // so an element's destructor can call back into this array without
  // deadlocking on the non-recursive mutex.
  for (size_t i = 0; i < head_.size(); ++i) {
    // Get() reads head_ slots without taking create_mutex_, so the slot is
    // detached atomically rather than by plain assignment.
    std::shared_ptr<TElem> p = std::atomic_exchange(&head_[i], std::shared_ptr<TElem>());
    unique_unlock<std::mutex> unlocker(&lock);
    // Release explicitly so the last reference (if this is it) is dropped
    // while unlocked; otherwise p would be destroyed after unlocker re-locks.
    p.reset();
  }
  for (size_t i = 0; i < more_.size(); ++i) {
    // Moving out leaves more_[i] empty.
    std::shared_ptr<TElem> p = std::move(more_[i]);
    unique_unlock<std::mutex> unlocker(&lock);
    p.reset();
  }
  more_.clear();
  is_clearing_.store(false);
}
template <typename TElem>
template <typename FVisit>
inline void LazyAllocArray<TElem>::ForEach(FVisit fvisit) {
  // Visit every populated slot as fvisit(global_index, raw_pointer), first the
  // fixed head_ storage and then the overflow vector, with create_mutex_ held
  // for the whole walk. Because std::mutex is not recursive, fvisit must not
  // call Clear(), ForEach(), or a Get() that reaches its locked path.
  std::lock_guard<std::mutex> guard(create_mutex_);
  size_t slot = 0;
  for (const auto& elem : head_) {
    if (elem) {
      fvisit(slot, elem.get());
    }
    ++slot;
  }
  // slot == kInitSize here, so overflow entries continue the global numbering.
  for (const auto& elem : more_) {
    if (elem) {
      fvisit(slot, elem.get());
    }
    ++slot;
  }
}
} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_LAZY_ALLOC_ARRAY_H_