/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file threaded_engine_pooled.cc
* \brief Pooled threaded engine
* \author Yutian Li
*/
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/concurrency.h>
#include <cassert>
#include <utility>
#include "./threaded_engine.h"
#include "./thread_pool.h"
#include "./stream_manager.h"
#if MXNET_USE_CUDA
#include "../common/cuda_utils.h"
#endif
namespace mxnet {
namespace engine {
/*!
 * \brief ThreadedEngine using a global thread pool shared across all devices.
 * The scheduling policy of this engine:
 * - Execute async operations immediately when they are pushed from the pusher thread.
 * - Use a common thread pool for normal operations on all devices.
 * - Use a dedicated thread pool for copy operations.
 */
class ThreadedEnginePooled : public ThreadedEngine {
public:
ThreadedEnginePooled() {
this->Start();
}
~ThreadedEnginePooled() noexcept(false) {
StopNoWait();
}
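  /*!
   * \brief Stop the engine without waiting for pending operations:
   *        signal both task queues for shutdown and release all resources.
   */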
void StopNoWait() {
streams_->Finalize();
task_queue_->SignalForKill();
io_task_queue_->SignalForKill();
task_queue_ = nullptr;
io_task_queue_ = nullptr;
thread_pool_ = nullptr;
io_thread_pool_ = nullptr;
streams_ = nullptr;
}
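  /*!
   * \brief Stop the engine after draining all pending operations.
   */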
void Stop() override {
WaitForAll();
StopNoWait();
}
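  /*!
   * \brief Create the stream manager, the task queues, and the worker
   *        thread pools (a shared pool plus one dedicated I/O thread).
   */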
void Start() override {
streams_.reset(new StreamManager<kMaxNumGpus, kNumStreamsPerGpu>());
task_queue_.reset(new dmlc::ConcurrentBlockingQueue<OprBlock*>());
io_task_queue_.reset(new dmlc::ConcurrentBlockingQueue<OprBlock*>());
thread_pool_.reset(new ThreadPool(kNumWorkingThreads,
[this](std::shared_ptr<dmlc::ManualEvent> ready_event) {
ThreadWorker(task_queue_, ready_event); },
true));
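    // A single dedicated I/O thread serializes copy operations on the I/O
    // stream, while normal operations share the larger pool above.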
io_thread_pool_.reset(new ThreadPool(1,
[this](std::shared_ptr<dmlc::ManualEvent> ready_event) {
ThreadWorker(io_task_queue_, ready_event); },
true));
}
protected:
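  /*!
   * \brief Accept an operator block for execution: async operations pushed
   *        from the pusher thread run inline, everything else is queued.
   * \param opr_block The operator block.
   * \param pusher_thread Whether the call originates from the pusher thread.
   */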
void PushToExecute(OprBlock *opr_block, bool pusher_thread) override {
if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) {
DoExecute(opr_block);
} else {
DoPushToQueue(opr_block);
}
}
private:
/*! \brief Number of worker threads in the shared thread pool */
static constexpr std::size_t kNumWorkingThreads = 16;
/*! \brief Maximum number of GPUs */
static constexpr std::size_t kMaxNumGpus = 16;
/*! \brief Number of streams allocated for each GPU */
static constexpr std::size_t kNumStreamsPerGpu = 16;
/*!
 * \brief Stream manager providing compute and I/O streams for each GPU.
 */
std::unique_ptr<StreamManager<kMaxNumGpus, kNumStreamsPerGpu>> streams_;
/*!
* \brief Task queues.
*/
std::shared_ptr<dmlc::ConcurrentBlockingQueue<OprBlock*>> task_queue_;
std::shared_ptr<dmlc::ConcurrentBlockingQueue<OprBlock*>> io_task_queue_;
/*!
* \brief Thread pools.
*/
std::unique_ptr<ThreadPool> thread_pool_;
std::unique_ptr<ThreadPool> io_thread_pool_;
/*!
 * \brief Worker loop passed to the thread pools.
 * \param task_queue Queue to pop operator blocks from.
 * \param ready_event Event signaled once the worker thread has started.
 */
void ThreadWorker(std::shared_ptr<dmlc::ConcurrentBlockingQueue<OprBlock*>> task_queue,
const std::shared_ptr<dmlc::ManualEvent>& ready_event) {
OprBlock* opr_block;
ready_event->signal();
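    // Pop() blocks until work arrives and returns false once the queue
    // receives SignalForKill(), which ends the worker loop.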
while (task_queue->Pop(&opr_block)) {
DoExecute(opr_block);
}
}
/*!
* \brief Execute an operation.
* \param opr_block The operator block.
*/
void DoExecute(OprBlock* opr_block) {
#if MXNET_USE_CUDA
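    // Thread-local CUDA device guard; created idle (-1) here and pointed
    // at the operation's device below once we know it targets a GPU.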
mxnet::common::cuda::DeviceStore device_store(-1, false);
#endif
assert(opr_block->wait.load() == 0);
if (opr_block->ctx.dev_mask() == gpu::kDevMask) {
#if MXNET_USE_CUDA
device_store.SetDevice(opr_block->ctx.dev_id);
#else // MXNET_USE_CUDA
LOG(FATAL) << "Please compile with CUDA enabled";
#endif // MXNET_USE_CUDA
}
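    // Copy operations run on the dedicated I/O stream so device transfers
    // can overlap with compute kernels on the regular streams.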
bool is_copy = (opr_block->opr->prop == FnProperty::kCopyFromGPU ||
opr_block->opr->prop == FnProperty::kCopyToGPU);
auto&& rctx = is_copy
? streams_->GetIORunContext(opr_block->ctx)
: streams_->GetRunContext(opr_block->ctx);
this->ExecuteOprBlock(rctx, opr_block);
}
/*!
* \brief Push the operation to the queue.
* \param opr_block The operator block.
*/
void DoPushToQueue(OprBlock* opr_block) {
switch (opr_block->opr->prop) {
case FnProperty::kCopyFromGPU:
case FnProperty::kCopyToGPU: {
io_task_queue_->Push(opr_block);
break;
}
default: {
task_queue_->Push(opr_block);
break;
}
}
}
};
Engine *CreateThreadedEnginePooled() {
return new ThreadedEnginePooled();
}
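/*
 * A minimal usage sketch (hypothetical; in practice MXNet obtains the engine
 * through Engine::Get(), which selects an implementation based on the
 * MXNET_ENGINE_TYPE environment variable, rather than calling this factory
 * directly):
 *
 *   Engine* engine = CreateThreadedEnginePooled();  // Start() runs in the ctor
 *   // ... schedule work with engine->PushAsync(...) / engine->PushSync(...)
 *   engine->WaitForAll();  // drain all pending operations
 *   delete engine;         // ~ThreadedEnginePooled() invokes StopNoWait()
 */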
} // namespace engine
} // namespace mxnet