blob: 060f4734a675209c5c845fb5a469aa4d54cff424 [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
*/
#ifndef MXNET_ENGINE_THREAD_POOL_H_
#define MXNET_ENGINE_THREAD_POOL_H_
#include <dmlc/base.h>
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>
#include "mxnet/base.h"
namespace mxnet {
namespace engine {
/*!
* \brief Thread pool.
*/
class ThreadPool {
 public:
  /*!
   * \brief Simple manually-signalled event gate which remains open.
   *
   * Once signal() has been called, every current and future wait() returns
   * immediately; the event is never reset. wait() and signal() both take the
   * internal mutex, so they may be called from different threads.
   */
  class SimpleEvent {
   public:
    SimpleEvent()
        : signaled_(false) {}

    /*!
     * \brief Block until signal() has been called (returns at once if it
     *        already has).
     *
     * The predicate overload re-checks the flag after every wakeup, so a
     * spurious wakeup cannot release the caller before signal() runs.
     */
    void wait() {
      std::unique_lock<std::mutex> lock(mutex_);
      condition_variable_.wait(lock, [this] { return signaled_.load(); });
    }

    /*!
     * \brief Open the gate and wake every waiter. Calling it again is harmless.
     */
    void signal() {
      // Publish the flag and notify while holding the mutex: the canonical
      // condition-variable pattern, keeping the flag update ordered with
      // waiters' predicate checks.
      std::lock_guard<std::mutex> lock(mutex_);
      signaled_ = true;
      condition_variable_.notify_all();
    }

    /*! \brief Signal event upon destruction, even for exceptions (RAII) */
    struct SetReadyOnDestroy {
      /*!
       * \param event event to signal on destruction; may be null, in which
       *        case destruction does nothing.
       */
      explicit inline SetReadyOnDestroy(std::shared_ptr<SimpleEvent> event)
          : event_(std::move(event)) {
      }
      inline ~SetReadyOnDestroy() {
        if (event_) {
          event_->signal();
        }
      }
      std::shared_ptr<SimpleEvent> event_;
    };

   private:
    /*! \brief Guards signaled_ together with condition_variable_. */
    std::mutex mutex_;
    std::condition_variable condition_variable_;
    /*! \brief True once signal() has been called; never reset. */
    std::atomic<bool> signaled_;
  };

  /*!
   * \brief Constructor takes function to run.
   * \param size size of the thread pool (number of threads started).
   * \param func the function to run on the thread pool; each thread runs its
   *        own copy of it.
   */
  explicit ThreadPool(size_t size, std::function<void()> func)
      : worker_threads_(size) {
    for (auto& worker : worker_threads_) {
      worker = std::thread(func);
    }
  }

  /*!
   * \brief Constructor that hands every worker a readiness event.
   * \param size size of the thread pool (number of threads started).
   * \param func the function to run on each thread. It receives that thread's
   *        own SimpleEvent, which it should signal (for example through
   *        SimpleEvent::SetReadyOnDestroy) once it is ready.
   * \param wait if true, block until every worker has signalled its event.
   *        A worker that never signals makes this constructor block forever.
   */
  explicit ThreadPool(size_t size,
                      std::function<void(std::shared_ptr<SimpleEvent> ready)> func,
                      const bool wait)
      : worker_threads_(size) {
    ready_events_.reserve(size);
    for (auto& worker : worker_threads_) {
      auto event = std::make_shared<SimpleEvent>();
      ready_events_.push_back(event);
      worker = std::thread(func, std::move(event));
    }
    if (wait) {
      WaitForReady();
    }
  }

  /*!
   * \brief Join every worker thread, blocking until each worker function has
   *        returned. std::thread::join may throw std::system_error, which is
   *        why the destructor is declared noexcept(false).
   */
  ~ThreadPool() noexcept(false) {
    for (auto& worker : worker_threads_) {
      // Only threads that were actually started can be joined.
      if (worker.joinable()) {
        worker.join();
      }
    }
  }

 private:
  /*!
   * \brief Wait for all started threads to signal that they're ready
   */
  void WaitForReady() {
    // Iterate by const reference: copying each shared_ptr would cost an
    // atomic refcount increment/decrement for nothing.
    for (const auto& event : ready_events_) {
      event->wait();
    }
  }
  /*!
   * \brief Worker threads.
   */
  std::vector<std::thread> worker_threads_;
  /*!
   * \brief Startup synchronization objects, one per worker (only populated by
   *        the readiness-event constructor).
   */
  std::vector<std::shared_ptr<SimpleEvent>> ready_events_;
  /*!
   * \brief Disallow default construction.
   */
  ThreadPool() = delete;
  /*!
   * \brief Disallow copy/move construction and assignment.
   */
  ThreadPool(const ThreadPool&) = delete;
  ThreadPool(ThreadPool&&) = delete;
  ThreadPool& operator=(const ThreadPool&) = delete;
  ThreadPool& operator=(ThreadPool&&) = delete;
};
} // namespace engine
} // namespace mxnet
#endif // MXNET_ENGINE_THREAD_POOL_H_