blob: 7f892cba7d3da32bb4b34ac910340cddad8e4efd [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef MXNET_ENGINE_THREAD_POOL_H_
#define MXNET_ENGINE_THREAD_POOL_H_
#include <dmlc/base.h>
#include <dmlc/thread_group.h>
#include <cstddef>
#include <functional>
#include <list>
#include <memory>
#include <thread>
#include <utility>
#include <vector>
#include "mxnet/base.h"
namespace mxnet {
namespace engine {
/*!
* \brief Thread pool.
*/
class ThreadPool {
 public:
  /*!
   * \brief Scope guard that signals an event when it goes out of scope.
   *
   * The signal is issued from the destructor, so it is delivered on every
   * exit path of the enclosing scope, including unwinding after an exception.
   */
  struct SetReadyOnDestroy {
    /*!
     * \brief Remember the event that has to be signalled later.
     * \param event event to signal on destruction; a null pointer is accepted
     *        and turns the destructor into a no-op.
     */
    explicit SetReadyOnDestroy(const std::shared_ptr<dmlc::ManualEvent>& event) : event_(event) {}
    /*! \brief Signal the stored event, if there is one. */
    ~SetReadyOnDestroy() {
      if (!event_) {
        return;
      }
      event_->signal();
    }
    /*! \brief Event signalled by the destructor. */
    std::shared_ptr<dmlc::ManualEvent> event_;
  };

  /*!
   * \brief Start \p size worker threads that all run \p func.
   * \param size number of threads in the pool; checked to be greater than zero.
   * \param func callable executed by every worker thread. Each thread is
   *        constructed from its own copy of the callable.
   */
  explicit ThreadPool(size_t size, std::function<void()> func) : worker_threads_(size) {
    // The vector already holds `size` default-constructed (not yet running)
    // threads at this point; the check only validates the requested count.
    CHECK_GT(size, 0);
    for (size_t idx = 0; idx < worker_threads_.size(); ++idx) {
      worker_threads_[idx] = std::thread(func);
    }
  }

  /*!
   * \brief Start \p size worker threads, handing each one its own ready event.
   * \param size number of threads in the pool; checked to be greater than zero.
   * \param func callable executed by every worker thread. It receives the
   *        event created for that thread (presumably to signal it once the
   *        thread has finished starting up, e.g. through SetReadyOnDestroy).
   * \param wait when true, the constructor does not return until every
   *        per-thread event has been waited on successfully.
   */
  explicit ThreadPool(size_t size,
                      std::function<void(std::shared_ptr<dmlc::ManualEvent> ready)> func,
                      const bool wait)
      : worker_threads_(size) {
    CHECK_GT(size, 0);
    for (std::thread& worker : worker_threads_) {
      // Register the event first, then give the new thread a copy of the
      // shared pointer so pool and worker both keep the event alive.
      ready_events_.push_back(std::make_shared<dmlc::ManualEvent>());
      worker = std::thread(func, ready_events_.back());
    }
    if (wait) {
      WaitForReady();
    }
  }

  /*!
   * \brief Join every worker thread.
   *
   * Declared noexcept(false) so that an exception thrown while joining is
   * propagated to the caller rather than terminating the program.
   */
  ~ThreadPool() noexcept(false) {
    for (auto it = worker_threads_.begin(); it != worker_threads_.end(); ++it) {
      it->join();
    }
  }

 private:
  /*!
   * \brief Block until each startup event, in creation order, has been waited on.
   */
  void WaitForReady() {
    for (auto it = ready_events_.begin(); it != ready_events_.end(); ++it) {
      (*it)->wait();
    }
  }

  /*!
   * \brief The pool's worker threads; joined in the destructor.
   */
  std::vector<std::thread> worker_threads_;
  /*!
   * \brief One startup event per worker thread (only filled by the
   *        constructor that takes a \c wait flag).
   */
  std::list<std::shared_ptr<dmlc::ManualEvent>> ready_events_;
  /*!
   * \brief A pool cannot be created without a size and a function to run.
   */
  ThreadPool() = delete;
  /*!
   * \brief A pool owns its threads and is neither copyable nor assignable.
   */
  DISALLOW_COPY_AND_ASSIGN(ThreadPool);
};
} // namespace engine
} // namespace mxnet
#endif // MXNET_ENGINE_THREAD_POOL_H_