blob: c5eaea13661e92dd92a71071d723b8d339f96d5a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file custom-inl.h
* \brief Registry of custom operator creators and the dispatcher that runs their callbacks.
* \author Junyuan Xie
*/
#ifndef MXNET_OPERATOR_CUSTOM_CUSTOM_INL_H_
#define MXNET_OPERATOR_CUSTOM_CUSTOM_INL_H_
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/c_api.h>
#include <mxnet/imperative.h>
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <functional>
#include <map>
#include <mutex>
#include <queue>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_set>
#include <utility>
#include <vector>
#include "../operator_common.h"
namespace mxnet {
namespace op {
namespace custom {
class CustomOperator {
public:
void Register(const std::string &op_type, CustomOpPropCreator creator) {
std::lock_guard<std::mutex> lock(mutex_);
if (registry_.find(op_type) != registry_.end()) {
LOG(WARNING) << "New registration is overriding existing custom operator " << op_type;
}
registry_[op_type] = creator;
}
CustomOpPropCreator Find(const std::string &op_type) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = registry_.find(op_type);
if (it != registry_.end()) return it->second;
return nullptr;
}
// For sparse the memory allocation is done during execution of operator
// which leads to changing of the pointers stored by ndarray chunk.
// Thus the changes to the copied ndarries don't propage to final
// inputs and outputs unlike the dense case. Passing vector of inputs and
// outputs ndarrays as args and updating the inputs and outputs ndarray
// chunk pointers to be same as the copied ndarrays.
template <typename Func>
void Push(const Func& func, const OpContext& ctx, bool recording,
bool training, const std::vector<NDArray>& arrs,
const std::vector<int>& tags,
const std::unordered_set<int>& output_tags,
const std::vector<NDArray>& outputs) {
if (naive_engine_) {
func();
for (size_t i = 0, out_idx = 0; i < arrs.size(); i++) {
if (arrs[i].storage_type() == kDefaultStorage ||
arrs[i].storage_type() == kUndefinedStorage)
continue;
if (output_tags.count(tags[i]) > 0) {
outputs[out_idx].SparseUpdateChunk(arrs[i]);
out_idx++;
}
}
ctx.async_on_complete();
return;
}
std::unique_lock<std::mutex> lock(mutex_);
q_.push([=]() mutable {
bool prev_recording = Imperative::Get()->set_is_recording(recording);
bool prev_training = Imperative::Get()->set_is_training(training);
func();
Imperative::Get()->set_is_training(prev_training);
Imperative::Get()->set_is_recording(prev_recording);
std::vector<Engine::VarHandle> vars, vars2;
size_t idx = 0;
for (const auto& i : arrs) {
vars.push_back(i.var());
if (output_tags.count(tags[idx]) > 0) {
if (i.storage_type() == kDefaultStorage ||
i.storage_type() == kUndefinedStorage)
continue;
vars2.push_back(i.var());
idx++;
}
}
Engine::Get()->PushSync(
[=](RunContext rctx) {
for (size_t i = 0, out_idx = 0; i < arrs.size(); i++) {
if (arrs[i].storage_type() == kDefaultStorage ||
arrs[i].storage_type() == kUndefinedStorage)
continue;
if (output_tags.count(tags[i]) > 0) {
outputs[out_idx].SparseUpdateChunk(arrs[i]);
out_idx++;
}
}
ctx.async_on_complete();
},
ctx.run_ctx.ctx, vars, vars2, FnProperty::kNormal, 0,
"CustomOperator");
});
// increase num_threads if there is not enough threads to execute custom operator
if (q_.size() > num_free_threads)
CreateThreads(q_.size() - num_free_threads);
cv_.notify_all();
}
static CustomOperator* Get() {
static CustomOperator inst;
return &inst;
}
void Start() {
num_free_threads = 0;
destructing_ = false;
naive_engine_ = true;
if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) {
naive_engine_ = false;
}
}
void Stop() {
if (naive_engine_) return;
{
std::unique_lock<std::mutex> lock(mutex_);
destructing_ = true;
cv_.notify_all();
}
for (auto &worker : workers_)
worker.join();
workers_.clear();
}
private:
CustomOperator() {
this->Start();
}
void ThreadTarget() {
std::unique_lock<std::mutex> lock(mutex_);
while (!q_.empty() || !destructing_) {
cv_.wait(lock, [&] {return !q_.empty() || destructing_;});
while (!q_.empty()) {
--num_free_threads;
auto fn = q_.front();
q_.pop();
lock.unlock();
fn();
++num_free_threads;
lock.lock();
}
}
}
void SetNumThreads(int num_threads) {
num_threads = std::min(dmlc::GetEnv("MXNET_CUSTOM_OP_NUM_THREADS", 16), num_threads);
for (int i = workers_.size(); i < num_threads; ++i) {
workers_.emplace_back(std::thread([this]{this->ThreadTarget();}));
++num_free_threads;
}
}
void CreateThreads(int num_new_threads) {
SetNumThreads(workers_.size() + num_new_threads);
}
std::mutex mutex_;
std::map<std::string, CustomOpPropCreator> registry_;
// async worker
std::condition_variable cv_;
std::vector<std::thread> workers_;
std::atomic<uint32_t> num_free_threads;
std::queue<std::function<void(void)> > q_;
bool naive_engine_;
bool destructing_;
};
} // namespace custom
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_CUSTOM_CUSTOM_INL_H_