blob: 73c8f4807b77ab9df1400206558e5f4a21d4c64f [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file attach_op_resource_pass.cc
* \brief Pass to attach resource to OpExecVector of the graph.
*/
#include <mxnet/resource.h>
#include <mxnet/op_attr_types.h>
#include "./exec_pass.h"
namespace mxnet {
namespace exec {
Graph AttachOpResources(Graph g) {
auto& fresource =
nnvm::Op::GetAttr<FResourceRequest>("FResourceRequest");
auto& op_execs = nnvm::get<OpExecVector>(*g.attrs.at("op_execs"));
const auto& vctx = g.GetAttr<ContextVector>("context");
const auto& idx = g.indexed_graph();
// Use global resource pool for each executor for now.
std::map<Context, Resource> cached_temp;
// Resource allocation
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
if (fresource.count(inode.source->op()) == 0) continue;
auto reqs = fresource[inode.source->op()](inode.source->attrs);
auto& requested = op_execs[nid]->op_ctx.requested;
requested.clear();
// Get the resource of temporal space.
for (const ResourceRequest& req : reqs) {
const Context &ctx = vctx[nid];
if (req.type == ResourceRequest::kTempSpace) {
if (cached_temp.count(ctx) != 0) {
requested.push_back(cached_temp.at(ctx));
} else {
Resource r = ResourceManager::Get()->Request(ctx, req);
requested.push_back(r);
cached_temp[ctx] = r;
}
} else if (req.type == ResourceRequest::kRandom) {
requested.push_back(ResourceManager::Get()->Request(ctx, req));
} else {
LOG(FATAL) << "resource type not yet supported";
}
}
}
return g;
}
} // namespace exec
} // namespace mxnet