blob: 4aebd08a6efb5dba293a27f4a40d9d1014462994 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file threaded_engine.h
* \brief Implements base class of threaded engine
* that tracks the dependency and pushes actions to execute.
* \author Yutian Li
*/
#ifndef MXNET_ENGINE_THREADED_ENGINE_H_
#define MXNET_ENGINE_THREADED_ENGINE_H_
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/omp.h>
#include <mxnet/storage.h>
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "./engine_impl.h"
#include "./openmp.h"
#include "../common/object_pool.h"
#include "../profiler/custom_op_profiler.h"
#include "../profiler/profiler.h"
namespace mxnet {
namespace engine {
// Define helper macros for debug information.
// When ENGINE_DEBUG is set, DEFINE_ENGINE_DEBUG_INFO(Type) injects a
// constructor/destructor pair that logs a running instance count for Type,
// making object-pool leaks of engine bookkeeping objects visible in the log.
// The static `counter` member must be defined once in a .cc file per Type.
#if ENGINE_DEBUG
#define DEFINE_ENGINE_DEBUG_INFO(Type) \
static std::atomic<std::size_t> counter; \
Type() { \
LOG(INFO) << __func__ << " " << ++counter; \
} \
~Type() { \
LOG(INFO) << __func__ << " " << --counter; \
}
#else
// Expands to nothing in release builds.
#define DEFINE_ENGINE_DEBUG_INFO(Type)
#endif
// Forward declarations
struct ThreadedOpr;
/*! shared_ptr to exception_ptr, used for exception handling */
typedef std::shared_ptr<std::exception_ptr> ExceptionRef;
/*!
* \brief Operation block in the scheduler.
* Each OprBlock corresponds to an operation pushed to the engine.
*/
struct OprBlock : public common::ObjectPoolAllocatable<OprBlock> {
  /*!
   * \brief wait number of pending tasks this OprBlock is waiting for.
   *  The block becomes runnable once this counter reaches zero.
   */
  std::atomic<int> wait{0};
  /*! \brief Pointer to information on performing the real operation */
  ThreadedOpr* opr{nullptr};
  /*! \brief The context this operator runs on */
  Context ctx;
  /*! \brief priority of the function */
  int priority;
  /*! \brief indicate whether to profile this operator */
  bool profiling{false};
  /*! \brief operator execution statistics */
  std::unique_ptr<profiler::ProfileOperator> opr_profile;
  // define possible debug information
  DEFINE_ENGINE_DEBUG_INFO(OprBlock);
  /*!
   * \brief call this function to decrease the wait counter.
   * \return the wait counter after the decrement.
   */
  inline int decr_wait() {
    // check invariant, avoid over trigger
    // (a negative value would mean a dependency was signaled twice)
    const int ret = --wait;
    CHECK_GE(ret, 0);
    return ret;
  }
};  // struct OprBlock
/*!
* \brief VersionedVarBlock that corresponding to a variable version.
* This is a basic unit of LinkedList in the ThreadedVar.
*/
struct VersionedVarBlock : public common::ObjectPoolAllocatable<VersionedVarBlock> {
  /*! \brief next block in the LinkedList */
  VersionedVarBlock* next{nullptr};
  /*! \brief the operation this block triggers (nullptr for the empty tail block) */
  OprBlock* trigger{nullptr};
  /*! \brief whether this operation is a write(mutate) operation. */
  bool write{false};
  /*! \brief define possible debug information */
  DEFINE_ENGINE_DEBUG_INFO(VersionedVarBlock);
};  // struct VersionedVarBlock
/*!
* \brief Variable implementation.
* Each ThreadedVar is a linked list(queue) of operations to be performed.
*/
class ThreadedVar final : public Var, public common::ObjectPoolAllocatable<ThreadedVar> {
 public:
  /*!
   * \brief constructor
   * \param head head block of the LinkedList,
   *  need to be initialized with next==nullptr and trigger=nullptr.
   */
  explicit ThreadedVar(VersionedVarBlock* head);
  /*!
   * \brief Schedule a read operation on this variable.
   *  If the opr_block can be run right away,
   *  the wait counter of opr_block will be decreased.
   *  Otherwise, the opr_block will be added to the waiting queue.
   * \param opr_block The operation to be scheduled.
   */
  inline void AppendReadDependency(OprBlock* opr_block);
  /*!
   * \brief Schedule a write operation on this variable.
   *  If the opr_block can be run right away,
   *  the wait counter of opr_block will be decreased.
   *  Otherwise, the opr_block will be added to the waiting queue.
   * \param opr_block The operation to be scheduled.
   */
  inline void AppendWriteDependency(OprBlock* opr_block);
  /*!
   * \brief A read operation is completed on this variable.
   *  This function may trigger subsequent waiting operations on this variable.
   *
   * \param dispatcher the function called to trigger the operation,
   *  when all of its dependencies are satisfied.
   * \tparam Dispatcher the function called to trigger an operation.
   */
  template <typename Dispatcher>
  inline void CompleteReadDependency(Dispatcher dispatcher);
  /*!
   * \brief A write operation is completed on this variable.
   *  This function may trigger subsequent waiting operations on this variable.
   *
   * \param dispatcher the function called to trigger the operation,
   *  when all of its dependencies are satisfied.
   * \tparam Dispatcher the function called to trigger an operation.
   * \return to_delete, whether this Variable can be deleted after this function.
   */
  template <typename Dispatcher>
  inline bool CompleteWriteDependency(Dispatcher dispatcher);
  /*! \brief Mark this variable to be deleted. */
  inline void SetToDelete();
  /*! \return whether this variable is ready to read. */
  inline bool ready_to_read();
  /*! \return the current version number of this variable. */
  inline size_t version() override;
  /*!
   * \brief Cast a Var pointer to ThreadedVar pointer
   * \param ptr pointer from base.
   * \return a casted pointer.
   */
  inline static ThreadedVar* CastFromBase(Var* ptr) {
    return ptr->Cast<ThreadedVar>();
  }
  // code for debug.
#if ENGINE_DEBUG
  static std::atomic<std::size_t> counter;
  ~ThreadedVar() {
    LOG(INFO) << __func__ << " " << --counter;
  }
#endif  // ENGINE_DEBUG
  /*!
   * \brief exception_ptr associated with the ThreadedOpr
   *  cannot modify state of exception object since dereferencing
   *  exception_ptr is undefined behavior. Using shared_ptr to hold
   *  exception_ptr and overcome this limitation */
  ExceptionRef var_exception;

 private:
  // TODO(hotpxl) change this to spinlock for faster runtime
  // TODO(hotpxl) consider rename head
  /*! \brief internal mutex of the ThreadedVar; guards all queue state below */
  std::mutex mutex_;
  /*!
   * \brief number of pending reads operation in the variable.
   *  will be marked as -1 (kWriteTriggered) when there is an already
   *  triggered pending write.
   */
  int num_pending_reads_{0};
  /*!
   * \brief Points to the last VersionedVarBlock in the queue.
   *  head_ always points to an empty VersionedVarBlock.
   *  So when we want to append an operation to the queue:
   *  1) update head_->trigger to be new op
   *  2) update head_->next to be a new VersionedVarBlock
   *  3) move head to head->next.
   */
  VersionedVarBlock* head_{nullptr};
  /*!
   * \brief The pointer to next write to perform.
   *  This pointer will only be updated when the write completes.
   *  This is actually the head(oldest operation) in the queue.
   */
  VersionedVarBlock* pending_write_{nullptr};
  /*!
   * \brief If true, delete after operation completes.
   */
  bool to_delete_{false};
  /*! \brief special const on num_pending_reads_ to mark write being triggered */
  static constexpr int kWriteTriggered = -1;
  /*!
   * \brief derived invariant of ready to read, without lock.
   *  A variable is readable iff no write is queued ahead of the reader.
   * \return whether the current variable is ready to read.
   */
  inline bool is_ready_to_read() const {
    return pending_write_ == nullptr;
  }
};  // struct ThreadedVar
/*!
* \brief Operator used in ThreadedEngine.
*/
struct ThreadedOpr final : public Opr, public common::ObjectPoolAllocatable<ThreadedOpr> {
  /*! \brief The function to be invoked each time. */
  Engine::AsyncFn fn;
  /*! \brief The variables this operation will read from. */
  std::vector<ThreadedVar*> const_vars;
  /*! \brief The variables this operation will mutate. */
  std::vector<ThreadedVar*> mutable_vars;
  /*! \brief The property of the operator */
  FnProperty prop;
  /*! \brief The name of the operator */
  std::string opr_name;
  /*!
   * \brief Whether this is a temporary operator
   *  that can be deleted right after the operation completed.
   */
  bool temporary{false};
  /*!
   * \brief Whether this is a WaitForVar operation
   */
  bool wait{false};
  /*!
   * \brief Cast an Opr pointer to ThreadedOpr pointer
   * \param ptr pointer from base.
   * \return a casted pointer.
   */
  inline static ThreadedOpr* CastFromBase(Opr* ptr) {
    return ptr->Cast<ThreadedOpr>();
  }
  // define possible debug information
  DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
  /*!
   * \brief exception_ptr associated with the ThreadedOpr
   *  cannot modify state of exception object since dereferencing
   *  exception_ptr is undefined behavior. Using shared_ptr to hold
   *  exception_ptr and overcome this limitation */
  ExceptionRef opr_exception;
};  // struct ThreadedOpr
/*!
* \brief Base class of all ThreadedEngine.
* This class implements a thread safe version of engine.
* The engine tracks the dependencies, and will call PushToExecute
* to execute a specific task.
*
* Subclass can implement PushToExecute to design specific
* execution policy for the tasks.
*/
class ThreadedEngine : public Engine {
 public:
  // implementing all the functions from Engine.
  ThreadedVar* NewVariable() override;
  ThreadedOpr* NewOperator(AsyncFn fn,
                           std::vector<VarHandle> const& const_vars,
                           std::vector<VarHandle> const& mutable_vars,
                           FnProperty prop = FnProperty::kNormal,
                           const char* opr_name = nullptr,
                           bool wait = false) override;
  void DeleteOperator(OprHandle op) override;
  void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override;
  void PushAsync(AsyncFn exec_fun,
                 Context exec_ctx,
                 std::vector<VarHandle> const& const_vars,
                 std::vector<VarHandle> const& mutable_vars,
                 FnProperty prop = FnProperty::kNormal,
                 int priority = 0,
                 const char* opr_name = nullptr,
                 bool wait = false) override;
  void PushSync(SyncFn exec_fn,
                Context exec_ctx,
                std::vector<VarHandle> const& const_vars,
                std::vector<VarHandle> const& mutable_vars,
                FnProperty prop = FnProperty::kNormal,
                int priority = 0,
                const char* opr_name = nullptr) override;
  void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override;
  void WaitForVar(VarHandle var) override;
  void WaitForAll() override;
  void Throw(VarHandle var) override;
  /*! \brief Enter shutdown phase: subsequently executed ops become no-ops,
   *  except variable-deletion cleanup (see ExecuteOprBlock). */
  void NotifyShutdown() override {
    shutdown_phase_.store(true);
  }
  ThreadedEngine() {
    engine_info_ = dmlc::GetEnv("MXNET_ENGINE_INFO", false);
    // Hold shared refs to the object pools and storage so they are not
    // destructed before this engine during static teardown.
    objpool_opr_ref_ = common::ObjectPool<ThreadedOpr>::_GetSharedRef();
    objpool_blk_ref_ = common::ObjectPool<OprBlock>::_GetSharedRef();
    objpool_varblk_ref_ = common::ObjectPool<VersionedVarBlock>::_GetSharedRef();
    objpool_var_ref_ = common::ObjectPool<ThreadedVar>::_GetSharedRef();
    storage_ref_ = Storage::_GetSharedRef();
    // Get a ref to the profiler so that it doesn't get killed before us
    profiler::Profiler::Get(&profiler_);
  }
  ~ThreadedEngine() {
    {
      // Take the lock so waiters cannot miss the notification between
      // checking kill_ and blocking on finished_cv_.
      std::unique_lock<std::mutex> lock{finished_m_};
      kill_.store(true);
    }
    finished_cv_.notify_all();
  }

 protected:
  /*!
   * \brief Push the opr block to execution queue to be executed.
   *  This function is implemented by the corresponding subclass
   *  for specific policy.
   *
   * \param opr_block The operator block.
   * \param pusher_thread whether the caller is the thread that calls push
   */
  virtual void PushToExecute(OprBlock* opr_block, bool pusher_thread) = 0;
  /*!
   * \brief Call this function to actually execute an opr_block
   *  This function also deletes the opr_block after execution.
   * \param run_ctx runtime context used to execute the function.
   * \param opr_block the opr_block to be executed and deleted.
   * \param on_start callback to invoke before the operation body runs.
   * \param callback callback to invoke when the operation completes.
   */
  void ExecuteOprBlock(RunContext run_ctx,
                       OprBlock* opr_block,
                       CallbackOnStart on_start,
                       CallbackOnComplete callback) {
    ThreadedOpr* threaded_opr = opr_block->opr;
    if (opr_block->profiling && threaded_opr->opr_name.size()) {
      std::unique_ptr<profiler::ProfileOperator::Attributes> attrs;
      if (profiler_->AggregateEnabled()) {
        attrs.reset(new profiler::ProfileOperator::Attributes());
      }
      const Context& ctx = opr_block->ctx;
      opr_block->opr_profile.reset(
          new profiler::ProfileOperator(threaded_opr->opr_name.c_str(), attrs.release()));
      opr_block->opr_profile->startForDevice(ctx.dev_type, ctx.dev_id);
    }
    const bool debug_info = (engine_info_ && debug_push_opr_ == opr_block);
    if (debug_info) {
      // Fixed missing separator so the pointer and flag don't run together.
      LOG(INFO) << "ExecuteOprBlock " << opr_block << " shutdown_phase=" << shutdown_phase_;
    }
    // still run cleanup in shutdown_phase
    if (!shutdown_phase_ || threaded_opr->prop == FnProperty::kDeleteVar) {
      try {
        OnStart(threaded_opr);
        if (debug_info) {
          LOG(INFO) << "ExecuteOprFn ";
        }
        try {
          // Skip the body if a dependency already failed, unless the op
          // opted out of skipping (kNoSkip) or is a WaitForVar op.
          if ((!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
               threaded_opr->prop == FnProperty::kNoSkip) ||
              threaded_opr->wait) {
            threaded_opr->fn(run_ctx, on_start, callback);
          } else {
            on_start();
            callback();
          }
        } catch (const std::exception& e) {
          // Record the exception so dependent ops can propagate it,
          // but still complete the block to release its dependencies.
          on_start();
          threaded_opr->opr_exception =
              std::make_shared<std::exception_ptr>(std::current_exception());
          callback();
        }
        if (debug_info) {
          LOG(INFO) << "Fin ExecuteOprFn ";
        }
      } catch (std::exception& e) {
        std::string what = e.what();
        if (what.find("driver shutting down") == std::string::npos && !shutdown_phase_) {
          LOG(FATAL) << e.what() << "\n"
                     << "A fatal error occurred in asynchronous engine operation. "
                        "If you do not know what caused this error, "
                        "you can try set environment variable MXNET_ENGINE_TYPE "
                        "to NaiveEngine and run with debugger (i.e. gdb). "
                        "This will force all operations to be synchronous and "
                        "backtrace will give you the series of calls that lead "
                        "to this error. Remember to set MXNET_ENGINE_TYPE back to "
                        "empty after debugging.";
        }
      }
    } else {
      // Shutting down: complete immediately without running the body.
      on_start();
      callback();
    }
  }
  /*!
   * \brief return the current bulk size; 0 while aggregate profiling runs,
   *  which disables bulking so each op is profiled individually.
   */
  int bulk_size() const override {
    const profiler::Profiler* prof = profiler::Profiler::Get();
    return (prof && prof->AggregateRunning()) ? 0 : BulkStatusStore::Get()->bulk_size;
  }
  /*!
   * \brief set maximum number of ops per bulk, flushing the current bulk
   *  if it already exceeds the new limit.
   * \return the previous bulk size.
   */
  int set_bulk_size(int bulk_size) override {
    BulkStatus& bulk_status = *BulkStatusStore::Get();
    // After the swap, `bulk_status.bulk_size` is the new limit and the
    // parameter `bulk_size` holds the previous limit (returned below).
    std::swap(bulk_status.bulk_size, bulk_size);
    if (bulk_status.count >= bulk_status.bulk_size)
      BulkFlush();
    if (!bulk_status.functions) {
      bulk_status.functions.reset(new std::vector<SyncFn>());
    }
    // Reserve for the NEW limit (the old code reserved the previous one).
    bulk_status.functions->reserve(bulk_status.bulk_size);
    return bulk_size;
  }

 protected:
  static void OnStartStatic(Engine* engine, void* opr_block, const dmlc::Error* error);
  static void OnCompleteStatic(Engine* engine, void* threaded_opr, const dmlc::Error* error);
#if MXNET_USE_CUDA
  static void OnStartCPU(Engine* engine, void* opr_block, const dmlc::Error* error);
  static void OnStartGPU(Engine* engine, void* sync_info, const dmlc::Error* error);
  static void OnCompleteGPU(Engine* engine, void* sync_info, const dmlc::Error* error);
  /*! \brief pooled bookkeeping passed through the GPU start/complete callbacks */
  struct GPUWorkerSyncInfo : public common::ObjectPoolAllocatable<GPUWorkerSyncInfo> {
    void* opr_block{nullptr};
    void* stream{nullptr};
    void* event_pool{nullptr};
  };
  std::shared_ptr<common::ObjectPool<GPUWorkerSyncInfo>> objpool_gpu_sync_ref_;
#endif

 private:
  /*! \brief structure for holding bulk execution status */
  struct BulkStatus {
    /*! \brief maximum number of ops per bulk */
    int bulk_size = 0;
    /*! \brief current number of ops in bulk */
    int count = 0;
    /*! \brief context of current ops */
    Context ctx;
    /*! \brief current op functions */
    std::shared_ptr<std::vector<SyncFn>> functions;
    /*! \brief constant variables */
    std::vector<VarHandle> const_vars;
    /*! \brief mutable variables */
    std::vector<VarHandle> mutable_vars;
  };
  /*! thread local store for bulk */
  typedef dmlc::ThreadLocalStore<BulkStatus> BulkStatusStore;
  /*!
   * \brief check if there is duplication in const_vars and mutable_vars.
   * \param const_vars the variables to read from.
   * \param mutable_vars the variables to mutate.
   */
  void CheckDuplicate(std::vector<VarHandle> const& const_vars,
                      std::vector<VarHandle> const& mutable_vars);
  /*!
   * \brief Callback on operation completion.
   *
   *  On operation completion, this will trigger subsequent operations.
   */
  inline void OnComplete(ThreadedOpr* threaded_opr);
  /*!
   * \brief rethrow caught exception in WaitForVar
   * \param threaded_var the var that we are waiting to read
   */
  inline void ThrowException(ThreadedVar* threaded_var);
  /*!
   * \brief Mark exceptions before operation execution.
   *
   *  Will mark the operator as a failure and associate exception_ptr
   *  if any of the read/write dependencies have an exception associated.
   */
  inline void OnStart(ThreadedOpr* threaded_opr) {
    for (auto&& i : threaded_opr->const_vars) {
      if (i->var_exception && *i->var_exception) {
        threaded_opr->opr_exception = i->var_exception;
        AddToGlobalExceptions(threaded_opr->opr_exception);
        break;
      }
    }
    // Only check mutable vars if no const var already carried an exception.
    if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception)) {
      for (auto&& i : threaded_opr->mutable_vars) {
        if (i->var_exception && *i->var_exception) {
          threaded_opr->opr_exception = i->var_exception;
          AddToGlobalExceptions(threaded_opr->opr_exception);
          break;
        }
      }
    }
  }
  /*!
   * \brief find exception in global_exception_refs and add it if missing
   * \param opr_exception the exception to be added to global_exception_refs
   */
  inline void AddToGlobalExceptions(const ExceptionRef& opr_exception) {
    auto it =
        std::find(global_exception_refs_.begin(), global_exception_refs_.end(), opr_exception);
    if (it == global_exception_refs_.end()) {
      global_exception_refs_.push_back(opr_exception);
    }
  }
  /*! \brief append an operator to bulk; flushes when the bulk is full */
  inline void BulkAppend(SyncFn exec_fn,
                         Context exec_ctx,
                         std::vector<VarHandle> const& const_vars,
                         std::vector<VarHandle> const& mutable_vars) {
    BulkStatus& bulk_status = *BulkStatusStore::Get();
    if (!bulk_status.functions) {
      bulk_status.functions.reset(new std::vector<SyncFn>());
    }
    bulk_status.functions->push_back(exec_fn);
    if (!bulk_status.count) {
      // First op in the bulk determines the bulk's execution context.
      bulk_status.ctx = exec_ctx;
    }
    ++bulk_status.count;
    bulk_status.const_vars.insert(
        bulk_status.const_vars.end(), const_vars.begin(), const_vars.end());
    bulk_status.mutable_vars.insert(
        bulk_status.mutable_vars.end(), mutable_vars.begin(), mutable_vars.end());
    if (bulk_status.count >= bulk_status.bulk_size)
      BulkFlush();
  }
  /*! \brief flush current bulk to execution as a single async op */
  inline void BulkFlush() {
    BulkStatus& bulk_status = *BulkStatusStore::Get();
    if (!bulk_status.count)
      return;
    bulk_status.count = 0;
    DeduplicateVarHandle(&bulk_status.const_vars, &bulk_status.mutable_vars);
    // Copy the shared_ptr into the lambda so the function list outlives
    // this thread-local status (it is reset right after PushAsync).
    auto functions = bulk_status.functions;
    this->PushAsync(
        [functions](RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
          on_start();
          for (auto& fn : *functions) {
            fn(ctx);
          }
          on_complete();
        },
        bulk_status.ctx,
        bulk_status.const_vars,
        bulk_status.mutable_vars,
        FnProperty::kNormal,
        0,
        "ImperativeBulk");
    bulk_status.functions.reset(new std::vector<SyncFn>());
    bulk_status.functions->reserve(bulk_status.bulk_size);
    bulk_status.const_vars.clear();
    bulk_status.mutable_vars.clear();
  }
  /*!
   * \brief Number of pending operations.
   */
  std::atomic<int> pending_{0};
  /*! \brief whether we want to kill the waiters */
  std::atomic<bool> kill_{false};
  /*! \brief whether it is during shutdown phase */
  std::atomic<bool> shutdown_phase_{false};
  /*! \brief show more information from engine actions */
  bool engine_info_{false};
  /*! \brief debug information about wait for var. */
  std::atomic<ThreadedVar*> debug_wait_var_{nullptr};
  /*! \brief debug information about the pushed operator being traced. */
  std::atomic<OprBlock*> debug_push_opr_{nullptr};
  /*!
   * \brief Mutex and condition_variable,
   *  used to Notify waits for single or all variables.
   */
  std::mutex finished_m_;
  std::condition_variable finished_cv_;
  /*! \brief global exception refs, which are rethrown when WaitForAll is called */
  std::vector<ExceptionRef> global_exception_refs_;
  /*!
   * \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early
   * See also #309 (https://github.com/apache/mxnet/issues/309)
   */
  std::shared_ptr<common::ObjectPool<ThreadedOpr>> objpool_opr_ref_;
  std::shared_ptr<common::ObjectPool<OprBlock>> objpool_blk_ref_;
  std::shared_ptr<common::ObjectPool<VersionedVarBlock>> objpool_varblk_ref_;
  std::shared_ptr<common::ObjectPool<ThreadedVar>> objpool_var_ref_;
  /*!
   * \brief Async destruction of some objects is relied on storage,
   *  prevent it from being destructed too early
   */
  std::shared_ptr<Storage> storage_ref_;
#if MXNET_USE_CUDA
  /*! \brief Number of GPU devices available (-1 means not yet queried) */
  std::atomic<int> device_count_{-1};
#endif
  /*! \brief Hold a ref count to the profiler */
  std::shared_ptr<profiler::Profiler> profiler_;
  /*!
   * \brief Disallow copy construction and assignment.
   * \note This must be last
   */
  DISALLOW_COPY_AND_ASSIGN(ThreadedEngine);
};  // class ThreadedEngine
} // namespace engine
} // namespace mxnet
#endif // MXNET_ENGINE_THREADED_ENGINE_H_