blob: 9345b6a13ab4434e7eba3acbd584863381736045 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file subgraph_lib.cc
* \brief subgraph operator implementation library file
*/
#include <cmath>
#include <iostream>
#include <algorithm>
#include <utility>
#include "mxnet/lib_api.h"
using namespace mxnet::ext;
/* Element-wise natural logarithm on float tensors: out[k] = logf(in[k]).
 * Writes in->size() floats into `out`, so `out` must have room for at least
 * that many elements.
 */
void myLog(MXTensor *in, MXTensor *out) {
  const int64_t count = in->size();
  const float* src = in->data<float>();
  float* dst = out->data<float>();
  for (int64_t idx = 0; idx < count; ++idx)
    dst[idx] = logf(src[idx]);
}
/* Element-wise exponential on float tensors: out[k] = expf(in[k]).
 * Writes in->size() floats into `out`, so `out` must have room for at least
 * that many elements.
 */
void myExp(MXTensor *in, MXTensor *out) {
  const int64_t count = in->size();
  const float* src = in->data<float>();
  float* dst = out->data<float>();
  for (int64_t idx = 0; idx < count; ++idx)
    dst[idx] = expf(src[idx]);
}
/* Execute the ops of a subgraph.
 * In MXNet, subgraphs are sorted in topological order,
 * so all we need to do is go through the ops in order
 * and execute each op.
 *
 * \param inputs   tensors bound, in node order, to the subgraph's "null" (input) nodes
 * \param outputs  tensors that receive a float copy of each subgraph output
 * \param subgraph graph to run; each node's `tensor` field is (re)assigned here
 * \return MX_SUCCESS, or MX_FAIL when an unsupported op is found or a
 *         temporary buffer cannot be allocated
 */
MXReturnValue myExecutor(std::vector<MXTensor>* inputs,
                         std::vector<MXTensor>* outputs,
                         mxnet::ext::Graph *subgraph) {
  std::cout << "Info: subgraph is: " << std::endl;
  subgraph->print();
  // counter for inputs
  int input_cnt = 0;
  // temporary tensor storage. node->tensor keeps pointers into this vector, so
  // capacity is reserved up front (each node adds at most one entry): a
  // reallocation inside push_back would leave earlier node->tensor pointers dangling.
  std::vector<MXTensor> data;
  data.reserve(subgraph->size());
  // track memory allocations to free later
  std::vector<void*> to_free;
  // release every temporary buffer allocated so far
  auto free_temporaries = [&to_free]() {
    for (void* ptr : to_free)
      free(ptr);  // NOLINT
    to_free.clear();
  };
  // loop over nodes
  for (int i = 0; i < subgraph->size(); i++) {
    mxnet::ext::Node* node = subgraph->getNode(i);
    // handle each op type
    if (node->op.compare("null") == 0) {
      // set tensor for this input to the subgraph
      node->tensor = &inputs->at(input_cnt++);
    } else if (node->op.compare("log") == 0 || node->op.compare("exp") == 0) {
      // get input tensor based on node ID inputs from data storage
      MXTensor *input = node->inputs.at(0).node->tensor;
      // create temporary float storage for the result
      void* buf = malloc(input->size() * sizeof(float));  // NOLINT
      if (buf == nullptr && input->size() > 0) {
        MX_ERROR_MSG << "Error! Failed to allocate temporary storage for op '"
                     << node->op << "' in myExecutor";
        free_temporaries();
        return MX_FAIL;
      }
      // save allocated ptr to free later
      to_free.push_back(buf);
      MXTensor tmp(buf, input->shape, input->dtype, 0, MXContext::CPU(0), kDefaultStorage);  // NOLINT
      // execute the requested operator
      if (node->op.compare("log") == 0)
        myLog(input, &tmp);
      else
        myExp(input, &tmp);
      // add output tensor to data storage (capacity reserved above, no reallocation)
      data.push_back(tmp);
      // set tensor for this node so we can read it later
      node->tensor = &data.back();
    } else {
      MX_ERROR_MSG << "Error! Unsupported op '" << node->op << "' found in myExecutor";
      // free allocated temporary storage
      free_temporaries();
      return MX_FAIL;
    }
  }
  // copy all operator results to outputs of subgraph
  for (size_t j = 0; j < subgraph->outputs.size(); j++) {
    // get computed result
    MXTensor *result = subgraph->outputs[j].node->tensor;
    // get output tensor to pass to MX
    MXTensor &out = outputs->at(j);
    float *out_data = out.data<float>();
    float *res_data = result->data<float>();
    // loop and copy data
    for (int64_t k = 0; k < result->size(); k++) {
      out_data[k] = res_data[k];
    }
  }
  // free allocated temporary storage
  free_temporaries();
  return MX_SUCCESS;
}
class MyStatefulOp : public CustomStatefulOp {
public:
explicit MyStatefulOp(std::string json,
const std::unordered_map<std::string, std::string>& attrs)
: attrs_(attrs) {
for (const auto &kv : attrs) {
std::cout << "subgraphOp attributes: " << kv.first << " ==> " << kv.second << std::endl;
}
subgraph_ = mxnet::ext::Graph::fromString(json);
}
MXReturnValue Forward(std::vector<MXTensor>* inputs,
std::vector<MXTensor>* outputs,
const OpResource& op_res) override {
if(attrs_.count(MX_STR_EXTRA_INPUTS) > 0 && std::stoi(attrs_.at(MX_STR_EXTRA_INPUTS)) > 0)
std::cout << "forward::extra_inputs(" << attrs_.at(MX_STR_EXTRA_INPUTS) << ")::inputs ["
<< inputs->size() << "]" << std::endl;
return myExecutor(inputs, outputs, subgraph_);
}
private:
mxnet::ext::Graph *subgraph_;
const std::unordered_map<std::string, std::string> attrs_;
};
/* State factory for _custom_subgraph_op: constructs a MyStatefulOp from the
 * JSON found under MX_STR_SUBGRAPH_SYM_JSON in `attrs`, falling back to the
 * placeholder "[empty]" when that attribute is missing.
 */
MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  // MXNet subgraph is stored as Symbol in operator node attrs subgraphs field;
  // a custom subgraph is stored as a json string in the custom operator attrs map
  const auto json_entry = attrs.find(MX_STR_SUBGRAPH_SYM_JSON);
  const std::string serialized_subgraph =
      (json_entry != attrs.end()) ? json_entry->second : std::string("[empty]");
  *op_inst = new MyStatefulOp(serialized_subgraph, attrs);
  std::cout << "Info: stateful operator created" << std::endl;
  return MX_SUCCESS;
}
// Register _custom_subgraph_op: flagged as a subgraph operator, with its
// per-instance state built by createOpState for the "cpu" context.
REGISTER_OP(_custom_subgraph_op)
.setIsSubgraphOp()
.setCreateOpState(createOpState, "cpu");
// allowlist of op names that the partitioners below may place into a subgraph
const std::vector<std::string> op_names{"exp", "log"};
/* Mark every node whose op is in op_names as a subgraph candidate.
 * \param graph   graph being partitioned
 * \param ids     one entry per node; entries for selected nodes are set to -1
 *                (include in any subgraph), others are left untouched
 * \param options partitioning options; when "reqFloat" is present, only nodes
 *                whose "dtype" attribute parses to kFloat32 are selected
 * \return MX_SUCCESS
 */
MXReturnValue mySupportedOps(const mxnet::ext::Graph* graph,
                             std::vector<int>* ids,
                             const std::unordered_map<std::string, std::string>& options) {
  for (const auto& kv : options) {
    std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
  }
  const bool require_float = options.count("reqFloat") > 0;
  // loop over nodes
  for (int i = 0; i < graph->size(); i++) {
    const mxnet::ext::Node *node = graph->getNode(i);
    // get dtype if available (-1 when the node has no "dtype" attribute)
    int dtype = -1;
    if (node->attrs.count("dtype") > 0)
      dtype = std::stoi(node->attrs.at("dtype"));
    // when float types are required, skip any node that is not float32
    if (require_float && dtype != kFloat32)
      continue;
    // check if op is in allowlist
    if (std::find(op_names.begin(), op_names.end(), node->op) != op_names.end()) {
      // found op in allowlist, set value to -1 to include op in any subgraph
      ids->at(i) = -1;
    }
  }
  return MX_SUCCESS;
}
/* Review a proposed subgraph: print it, then accept it unless the option
 * "reject" is exactly "True" (used for testing). Every reviewed subgraph gets
 * the attribute myKey=myVal added (emplace leaves an existing myKey untouched).
 * \return MX_SUCCESS
 */
MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_id, bool* accept,
                               const std::unordered_map<std::string, std::string>& options,
                               std::unordered_map<std::string, std::string>* attrs) {
  for (const auto& kv : options) {
    std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
  }
  std::string sg = subgraph->toString();
  std::cout << "subgraph " << subgraph_id << ": " << std::endl;
  std::cout << sg << std::endl;
  // check if option `reject` was specified, and if so check if value is 'True'
  const auto reject_opt = options.find("reject");
  if (reject_opt != options.end() && reject_opt->second == "True") {
    // if specified, reject the subgraph. this is only used for testing
    *accept = false;
    std::cout << "rejecting subgraph" << std::endl;
  } else {
    *accept = true;
    std::cout << "accepting subgraph" << std::endl;
  }
  attrs->emplace("myKey", "myVal");
  return MX_SUCCESS;
}
// Partitioner "myProp": strategy1 groups nodes chosen by mySupportedOps into
// _custom_subgraph_op nodes, each subgraph reviewed by myReviewSubgraph.
REGISTER_PARTITIONER(myProp)
.addStrategy("strategy1", "_custom_subgraph_op")
.setSupportedOps("strategy1", mySupportedOps)
.setReviewSubgraph("strategy1", myReviewSubgraph);
class MySelector : public CustomOpSelector {
public:
MySelector(const mxnet::ext::Graph *graph,
const std::unordered_map<std::string, std::string>& options) :
graph_(graph), options_(options) {
for (auto kv : options) {
std::cout << "selector options: " << kv.first
<< " ==> " << kv.second << std::endl;
}
}
bool chooseNode(int nodeID) {
const mxnet::ext::Node *node = graph_->getNode(nodeID);
//get shape/type if available
std::string shape;
int dtype = -1;
if(node->attrs.count("shape") > 0)
shape = node->attrs.at("shape");
if(node->attrs.count("dtype") > 0)
dtype = std::stoi(node->attrs.at("dtype"));
//check if op dtype is float, and if option was specified to require float types
if((dtype == kFloat32 && options_.count("reqFloat") > 0) || options_.count("reqFloat") == 0) {
//check if op is in allowlist
if(std::find(op_names.begin(),op_names.end(),node->op.c_str()) != op_names.end()) {
// found op in allowlist, return true to include op subgraph
return true;
}
}
return false;
}
bool Select(int nodeID) override {
return chooseNode(nodeID);
}
bool SelectInput(int nodeID, int input_nodeID) override {
return chooseNode(input_nodeID);
}
bool SelectOutput(int nodeID, int output_nodeID) override {
return chooseNode(output_nodeID);
}
virtual void Filter(std::vector<int>& candidates,
std::vector<int>& keep) {
keep.insert(keep.end(), candidates.begin(), candidates.end());
}
void Reset() override {}
private:
const mxnet::ext::Graph *graph_;
const std::unordered_map<std::string, std::string> options_;
};
/* Selector factory for the mySelect partitioner: returns, through `sel_inst`,
 * a newly allocated MySelector over `graph` configured with `options`.
 */
MXReturnValue createSelector(const mxnet::ext::Graph *graph, CustomOpSelector** sel_inst,
                             const std::unordered_map<std::string, std::string>& options) {
  MySelector* selector = new MySelector(graph, options);
  *sel_inst = selector;
  std::cout << "Info: selector created" << std::endl;
  return MX_SUCCESS;
}
// Partitioner "mySelect": strategy1 builds _custom_subgraph_op nodes using the
// selector from createSelector, each subgraph reviewed by myReviewSubgraph.
REGISTER_PARTITIONER(mySelect)
.addStrategy("strategy1", "_custom_subgraph_op")
.setCreateSelector("strategy1", createSelector)
.setReviewSubgraph("strategy1", myReviewSubgraph);
/* \brief a basic pass that adds a new input for subgraph ops
 * For every node whose op is "_custom_subgraph_op": set its MX_STR_EXTRA_INPUTS
 * attribute to "1", create a "null" node named "<node name>_input", register it
 * as a graph input, connect it as the node's last input, and allocate a
 * 1-element float32 CPU argument tensor for it.
 * \return MX_SUCCESS
 */
MXReturnValue addInputPass(mxnet::ext::Graph *graph,
                           const std::unordered_map<std::string, std::string>& options) {
  // visit only the nodes that existed before the pass; the input nodes appended
  // below are "null" ops and never need processing
  const int num_nodes = static_cast<int>(graph->size());
  // find node with '_custom_subgraph_op' op type
  for (int i = 0; i < num_nodes; i++) {
    mxnet::ext::Node* n = graph->getNode(i);
    if (n->op.compare("_custom_subgraph_op") != 0)
      continue;
    // set extra input
    n->attrs[MX_STR_EXTRA_INPUTS] = std::to_string(1);
    // create a new input Node
    Node* input = graph->addNode(n->name + "_input", "null");
    // set this node as an input in the graph
    graph->inputs.push_back(input);
    // connect new input to node, at the index of the node's next input slot
    input->outputs.push_back({n, static_cast<int>(n->inputs.size())});
    // connect node to new input
    n->inputs.push_back({input, 0});
    // add a corresponding tensor for this input
    input->alloc_arg({1}, MXContext::CPU(0), kFloat32);
  }
  return MX_SUCCESS;
}
// Register graph pass "addInputPass", implemented by addInputPass above.
REGISTER_PASS(addInputPass)
.setBody(addInputPass);
/* Library initialization hook: rejects MXNet version codes below 10700 with an
 * error message, otherwise logs support and succeeds.
 */
MXReturnValue initialize(int version) {
  if (version < 10700) {
    MX_ERROR_MSG << "MXNet version " << version << " not supported by custom library" << std::endl;
    return MX_FAIL;
  }
  std::cout << "MXNet version " << version << " supported" << std::endl;
  return MX_SUCCESS;
}