/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file subgraph_lib.cc
 * \brief subgraph operator implementation library file
 */
#include <algorithm>
#include <cmath>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "mxnet/lib_api.h"

using namespace mxnet::ext;

/* function to execute log operator on floats */
void myLog(MXTensor* in, MXTensor* out) {
  float* inp = in->data<float>();
  float* outp = out->data<float>();
  for (int64_t i = 0; i < in->size(); i++) {
    outp[i] = logf(inp[i]);
  }
}

/* function to execute exp operator on floats */
void myExp(MXTensor* in, MXTensor* out) {
  float* inp = in->data<float>();
  float* outp = out->data<float>();
  for (int64_t i = 0; i < in->size(); i++) {
    outp[i] = expf(inp[i]);
  }
}

/* function to execute the ops in a subgraph
 * In MXNet, subgraphs are sorted in topological order,
 * so all we need to do is go through the ops in order
 * and execute each op.
 */
MXReturnValue myExecutor(std::vector<MXTensor>* inputs,
                         std::vector<MXTensor>* outputs,
                         mxnet::ext::Graph* subgraph) {
  std::cout << "Info: subgraph is: " << std::endl;
  subgraph->print();

  // counter for inputs
  int input_cnt = 0;
  // temporary tensor storage; reserve up front so pointers stored in
  // node->tensor stay valid while we push_back (avoids reallocation)
  std::vector<MXTensor> data;
  data.reserve(subgraph->size());
  // track memory allocations to free later
  std::vector<void*> to_free;

  // loop over nodes
  for (int i = 0; i < subgraph->size(); i++) {
    mxnet::ext::Node* node = subgraph->getNode(i);
    // handle each op type
    if (node->op.compare("null") == 0) {
      // a 'null' node is a subgraph input: bind it to the next incoming tensor
      node->tensor = &inputs->at(input_cnt++);
    } else if (node->op.compare("log") == 0) {
      // get the tensor produced by this node's input node
      MXTensor* input = node->inputs.at(0).node->tensor;
      // create temporary storage (assumes float32, 4 bytes per element)
      MXTensor tmp(malloc(input->size() * 4),
                   input->shape,
                   input->dtype,
                   0,
                   MXContext::CPU(0),
                   kDefaultStorage);  // NOLINT
      // save allocated ptr to free later
      to_free.push_back(tmp.data_ptr);
      // execute log operator
      myLog(input, &tmp);
      // add output tensor to data storage
      data.push_back(tmp);
      // set tensor for this node so we can read it later
      node->tensor = &data.back();
    } else if (node->op.compare("exp") == 0) {
      // get the tensor produced by this node's input node
      MXTensor* input = node->inputs.at(0).node->tensor;
      // create temporary storage (assumes float32, 4 bytes per element)
      MXTensor tmp(malloc(input->size() * 4),
                   input->shape,
                   input->dtype,
                   0,
                   MXContext::CPU(0),
                   kDefaultStorage);  // NOLINT
      // save allocated ptr to free later
      to_free.push_back(tmp.data_ptr);
      // execute exp operator
      myExp(input, &tmp);
      // add output tensor to data storage
      data.push_back(tmp);
      // set tensor for this node so we can read it later
      node->tensor = &data.back();
    } else {
      MX_ERROR_MSG << "Error! Unsupported op '" << node->op << "' found in myExecutor";
      // free allocated temporary storage
      for (void* ptr : to_free)
        free(ptr);  // NOLINT
      return MX_FAIL;
    }
  }

  // copy all operator results to outputs of subgraph
  for (int j = 0; j < subgraph->outputs.size(); j++) {
    // get computed result
    MXTensor* result = subgraph->outputs[j].node->tensor;
    // get output tensor to pass to MX
    MXTensor& out = outputs->at(j);
    float* out_data = out.data<float>();
    float* res_data = result->data<float>();
    // loop and copy data
    for (int64_t i = 0; i < result->size(); i++) {
      out_data[i] = res_data[i];
    }
  }

  // free allocated temporary storage
  for (void* ptr : to_free) {
    free(ptr);  // NOLINT
  }

  return MX_SUCCESS;
}
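
/* \brief a custom stateful operator that holds the deserialized subgraph
 * and executes it with myExecutor on every Forward call
 */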
class MyStatefulOp : public CustomStatefulOp {
 public:
  explicit MyStatefulOp(std::string json, const std::unordered_map<std::string, std::string>& attrs)
      : attrs_(attrs) {
    for (const auto& kv : attrs) {
      std::cout << "subgraphOp attributes: " << kv.first << " ==> " << kv.second << std::endl;
    }
    subgraph_ = mxnet::ext::Graph::fromString(json);
  }

  MXReturnValue Forward(std::vector<MXTensor>* inputs,
                        std::vector<MXTensor>* outputs,
                        const OpResource& op_res) override {
    if (attrs_.count(MX_STR_EXTRA_INPUTS) > 0 && std::stoi(attrs_.at(MX_STR_EXTRA_INPUTS)) > 0)
      std::cout << "forward::extra_inputs(" << attrs_.at(MX_STR_EXTRA_INPUTS) << ")::inputs ["
                << inputs->size() << "]" << std::endl;
    return myExecutor(inputs, outputs, subgraph_);
  }

 private:
  mxnet::ext::Graph* subgraph_;
  const std::unordered_map<std::string, std::string> attrs_;
};
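
/* \brief create an instance of MyStatefulOp for a '_custom_subgraph_op' node.
 * The serialized subgraph (if any) is passed in through the attrs map.
 */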
MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  std::string serialized_subgraph = "[empty]";
  // MXNet subgraph is stored as Symbol in operator node attrs subgraphs field
  // custom subgraph is stored as json string in custom operator attrs map entry
  if (attrs.count(MX_STR_SUBGRAPH_SYM_JSON)) {
    // user can now parse json and run other custom ops inside subgraph
    serialized_subgraph = attrs.at(MX_STR_SUBGRAPH_SYM_JSON);
  }
  *op_inst = new MyStatefulOp(serialized_subgraph, attrs);
  std::cout << "Info: stateful operator created" << std::endl;
  return MX_SUCCESS;
}
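
// register the subgraph op that partitioned subgraphs are replaced with; it runs on CPU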
REGISTER_OP(_custom_subgraph_op).setIsSubgraphOp().setCreateOpState(createOpState, "cpu");
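
// allowlist of ops that the partitioners in this library will claim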
const std::vector<std::string> op_names({"exp", "log"});
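
/* \brief partitioner callback: mark each node that should be included in a subgraph
 * by setting its entry in 'ids' (-1 means the node can go in any subgraph)
 */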
MXReturnValue mySupportedOps(const mxnet::ext::Graph* graph,
                             std::vector<int>* ids,
                             const std::unordered_map<std::string, std::string>& options) {
  for (auto kv : options) {
    std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
  }

  // loop over nodes
  for (int i = 0; i < graph->size(); i++) {
    const mxnet::ext::Node* node = graph->getNode(i);

    // get shape/type if available
    std::string shape;
    int dtype = -1;
    if (node->attrs.count("shape") > 0)
      shape = node->attrs.at("shape");
    if (node->attrs.count("dtype") > 0)
      dtype = std::stoi(node->attrs.at("dtype"));

    // if the 'reqFloat' option was given, require the op dtype to be float32; otherwise accept any dtype
    if ((dtype == kFloat32 && options.count("reqFloat") > 0) || options.count("reqFloat") == 0) {
      // check if op is in allowlist
      if (std::find(op_names.begin(), op_names.end(), node->op.c_str()) != op_names.end()) {
        // found op in allowlist, set value to -1 to include op in any subgraph
        ids->at(i) = -1;
      }
    }
  }
  return MX_SUCCESS;
}
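
/* \brief review each candidate subgraph: accept or reject it, and attach
 * extra attributes that are passed on to the subgraph operator
 */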
MXReturnValue myReviewSubgraph(const mxnet::ext::Graph* subgraph,
                               int subgraph_id,
                               bool* accept,
                               const std::unordered_map<std::string, std::string>& options,
                               std::unordered_map<std::string, std::string>* attrs) {
  for (auto kv : options) {
    std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
  }
  std::string sg = subgraph->toString();
  std::cout << "subgraph " << subgraph_id << ": " << std::endl;
  std::cout << sg << std::endl;

  // check if option `reject` was specified, and if so check if value is 'True'
  if (options.count("reject") > 0 && options.at("reject").compare("True") == 0) {
    // if specified, reject the subgraph. this is only used for testing
    *accept = false;
    std::cout << "rejecting subgraph" << std::endl;
  } else {
    *accept = true;
    std::cout << "accepting subgraph" << std::endl;
  }

  // attach a custom attribute that the subgraph op will receive
  attrs->emplace("myKey", "myVal");
  return MX_SUCCESS;
}
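
/* \brief register the partitioner that uses the supportedOps callback.
 * Note: from Python this partitioner would typically be applied with something
 * like sym.optimize_for("myProp", ...) after loading the library (this is an
 * assumption; see the example's README for the exact usage).
 */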
REGISTER_PARTITIONER(myProp)
.addStrategy("strategy1", "_custom_subgraph_op")
.setSupportedOps("strategy1", mySupportedOps)
.setReviewSubgraph("strategy1", myReviewSubgraph);
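
/* \brief a custom selector: an alternative to the supportedOps callback where
 * the library itself decides, node by node, what to include in a subgraph
 */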
class MySelector : public CustomOpSelector {
 public:
  MySelector(const mxnet::ext::Graph* graph,
             const std::unordered_map<std::string, std::string>& options)
      : graph_(graph), options_(options) {
    for (auto kv : options) {
      std::cout << "selector options: " << kv.first << " ==> " << kv.second << std::endl;
    }
  }

  bool chooseNode(int nodeID) {
    const mxnet::ext::Node* node = graph_->getNode(nodeID);

    // get shape/type if available
    std::string shape;
    int dtype = -1;
    if (node->attrs.count("shape") > 0)
      shape = node->attrs.at("shape");
    if (node->attrs.count("dtype") > 0)
      dtype = std::stoi(node->attrs.at("dtype"));

    // if the 'reqFloat' option was given, require the op dtype to be float32; otherwise accept any dtype
    if ((dtype == kFloat32 && options_.count("reqFloat") > 0) || options_.count("reqFloat") == 0) {
      // check if op is in allowlist
      if (std::find(op_names.begin(), op_names.end(), node->op.c_str()) != op_names.end()) {
        // found op in allowlist, return true to include op in subgraph
        return true;
      }
    }
    return false;
  }

  bool Select(int nodeID) override {
    return chooseNode(nodeID);
  }
  bool SelectInput(int nodeID, int input_nodeID) override {
    return chooseNode(input_nodeID);
  }
  bool SelectOutput(int nodeID, int output_nodeID) override {
    return chooseNode(output_nodeID);
  }
  virtual void Filter(std::vector<int>& candidates, std::vector<int>& keep) {
    keep.insert(keep.end(), candidates.begin(), candidates.end());
  }
  void Reset() override {}

 private:
  const mxnet::ext::Graph* graph_;
  const std::unordered_map<std::string, std::string> options_;
};
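
/* \brief create the selector instance used by the 'mySelect' partitioner */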
MXReturnValue createSelector(const mxnet::ext::Graph* graph,
                             CustomOpSelector** sel_inst,
                             const std::unordered_map<std::string, std::string>& options) {
  *sel_inst = new MySelector(graph, options);
  std::cout << "Info: selector created" << std::endl;
  return MX_SUCCESS;
}
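
// register a second partitioner that uses the selector API instead of supportedOps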
REGISTER_PARTITIONER(mySelect)
.addStrategy("strategy1", "_custom_subgraph_op")
.setCreateSelector("strategy1", createSelector)
.setReviewSubgraph("strategy1", myReviewSubgraph);

/* \brief a basic graph pass that adds a new input for subgraph ops */
MXReturnValue addInputPass(mxnet::ext::Graph* graph,
                           const std::unordered_map<std::string, std::string>& options) {
  // find nodes with the '_custom_subgraph_op' op type
  for (int i = 0; i < graph->size(); i++) {
    mxnet::ext::Node* n = graph->getNode(i);
    if (n->op.compare("_custom_subgraph_op") == 0) {
      // tell the subgraph op how many extra inputs were appended
      n->attrs[MX_STR_EXTRA_INPUTS] = std::to_string(1);

      // create a new input Node
      Node* input = graph->addNode(n->name + "_input", "null");
      // set this node as an input in the graph
      graph->inputs.push_back(input);
      // connect new input to node
      input->outputs.push_back({n, static_cast<int>(n->inputs.size())});
      // connect node to new input
      n->inputs.push_back({input, 0});
      // add a corresponding tensor for this input
      input->alloc_arg({1}, MXContext::CPU(0), kFloat32);
    }
  }
  return MX_SUCCESS;
}
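
// register the graph pass under the name 'addInputPass'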
REGISTER_PASS(addInputPass).setBody(addInputPass);
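
/* \brief called by MXNet when the library is loaded; verifies the MXNet version */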
MXReturnValue initialize(int version) {
  // this library requires MXNet 1.7.0 or later (version 10700)
  if (version >= 10700) {
    std::cout << "MXNet version " << version << " supported" << std::endl;
    return MX_SUCCESS;
  } else {
    MX_ERROR_MSG << "MXNet version " << version << " not supported by custom library" << std::endl;
    return MX_FAIL;
  }
}