
Custom Graph Pass Example and Tutorial

Introduction

Adding custom graph passes in MXNet used to require deep understanding of the MXNet backend, including nnvm pass registration and other internal classes, followed by recompiling MXNet from source. This feature allows adding custom graph passes by dynamically loading external libraries at runtime.

This custom graph pass feature enables users to write custom model modification strategies without compiling against all of MXNet header files and dependencies. When a library containing custom passes is loaded dynamically, the components found in the library will be registered in MXNet so that users can use those natively just like other built-in components.

Getting Started

Have MXNet Ready

To run the following example, the build type of MXNet doesn’t matter since the custom pass doesn’t interact with the execution of other native MXNet features. Note that if you want to use your custom pass with models running on GPU, you still need an MXNet CUDA build.

Run An Example

You can start getting familiar with custom passes by running an example provided in the example/extensions/lib_pass directory. The myPass example just prints out the graph. Go to the lib_pass directory and follow these steps:

  1. Run make. The Makefile will generate the dynamic library libpass_lib.so which is compiled from the pass_lib.cc file. This is the library you are going to load that contains everything for the custom pass.
  2. Run python test_pass.py. It first loads the library built above, finds its components, and registers them in the MXNet backend. It then executes the pass on the model, runs the model's operators like regular MXNet operators, and outputs the result. Below is the output when running the python test_pass.py command. Notice that it loads 1 pass: myPass.
[10:38:03] src/c_api/c_api.cc:286: Found 0 operators in library
[10:38:03] src/c_api/c_api.cc:785: Found 0 partitioners in library
[07:14:00] src/c_api/c_api.cc:887: Found 1 graph passes in library
[07:14:00] src/c_api/c_api.cc:902:       Graph Pass [0] myPass

Basic Files For Custom Pass Library

  • lib_pass/pass_lib.cc: This file contains the source code for all components required to make a custom pass. It also shows how to register them so that they can be loaded by MXNet.
  • lib_pass/Makefile: This file compiles the source code to a dynamic shared library, with a header file include/mxnet/lib_api.h from MXNet source code. Currently the custom pass is compatible with C++11 and above.
  • lib_pass/test_pass.py: This file calls mx.library.load('libpass_lib.so') to load the library containing the custom components, executes the pass on the model using the optimize_for API, and prints outputs of the forward passes. The outputs should be the same as the regular MXNet forward pass without running the pass.
  • include/mxnet/lib_api.h: This file from MXNet source code is the single header file needed to include all necessary data types and function prototypes for writing a custom library. You can either specify the include path in the Makefile, or copy the header file over to example/extensions/lib_pass folder. Note that apart from this header, the custom library is independent of MXNet source.

Writing Custom Pass Library

To build your own library containing a custom pass, write a C++ source file such as mypass_lib.cc, include the lib_api.h header file, and implement your custom pass with these essential functions:

  • initialize - Library Initialization Function
  • REGISTER_PASS - Pass Registration Macro
  • graphPass - Pass Implementation

Then compile it to the libmypass_lib.so dynamic library using the following command:

g++ -shared -fPIC -std=c++11 mypass_lib.cc -o libmypass_lib.so -I ../../../include/mxnet

Finally, you can write a Python script to load the library and execute your pass on a model:

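import os
import mxnet as mx
from mxnet.gluon import nn

# Passing an absolute path is the safe choice when loading a library
mx.library.load(os.path.abspath('libmypass_lib.so'))
sym, _, _ = mx.model.load_checkpoint('mymodel', 0)
# Symbol/Module flow
sym2 = sym.optimize_for("myPass")
# Gluon flow 1 (inputs is a placeholder for the model's input Symbol(s))
sym_block = nn.SymbolBlock(sym, inputs)
sym_block.hybridize(backend='myPass')
# Gluon flow 2 (x is a placeholder for an example input NDArray)
sym_block = nn.SymbolBlock(sym, inputs)
sym_block.optimize_for(x, backend='myPass')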

Using a Custom Pass Library

Custom passes can be run through both the Symbol and Gluon APIs. For the Symbol API, optimize_for can be called on Symbol objects to run the graph pass and return a new Symbol.

sym.optimize_for(backend, args=None, aux=None, ctx=None, **kwargs)

The optimize_for API takes at least 1 argument, backend, which is a string that identifies which backend to use to optimize the model. The args and aux arguments are optional and take a list of NDArray or a dict of str to NDArray. They are used to infer shapes and types before executing the graph pass. The ctx argument is optional and takes a device context used to infer storage types. The API also accepts any other user-specified options, which are passed to the backend APIs.

For the Gluon API, hybridize can be called on HybridBlocks to execute a graph pass on the internal CachedOp Symbol.

block.hybridize(backend=None, **kwargs)

The hybridize function prepares the HybridBlock to be converted into a backend symbol. The backend argument is a string that identifies which pass will be executed on the model. **kwargs may contain other user-specified options that will be passed to the backend APIs. The actual pass runs once, just before the first forward pass.

If you just want to run a graph pass on the HybridBlock but not run a complete forward pass, you can use the optimize_for API that combines the work done in the hybridize API with part of the work done in the forward pass.

block.optimize_for(x, backend=None, **kwargs)

When the optimize_for API is called on a HybridBlock it runs the graph pass immediately. This lets users export the modified model without running a complete forward pass.

block.optimize_for(x, backend='myPass')
block.export('optimized')

But you can also use optimize_for in place of hybridize and run inference immediately after too.

block.optimize_for(x, backend='myPass')
block(x)

Writing A Custom Graph Pass

There are several essential building blocks for making a custom pass:

  • initialize:
    • This function is the library initialization function necessary for any dynamic library. It lets you check whether the user is running a compatible version of MXNet. Note that the version parameter is passed from MXNet when the library is loaded.
            MXReturnValue initialize(int version)
  • graphPass:
    • This function provides a copy of the model graph, and any specific options from the user.
            MXReturnValue graphPass(
                mxnet::ext::Graph *g,
                const std::unordered_map<std::string, std::string>& options)
  • REGISTER_PASS(my_pass_name):
    • This macro registers the custom pass and its properties to MXNet by its name. The argument to setBody is the graphPass function.
            REGISTER_PASS(my_pass_name)
            .setBody(graphPass);
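
As a rough illustration of the version check that initialize enables, here is a minimal sketch. It declares a stand-in MXReturnValue enum only so that it compiles without lib_api.h; in a real library that type comes from the header. The minimum version 10700 is an assumed threshold chosen for illustration.

```cpp
#include <iostream>

// Stand-in for the type provided by lib_api.h, declared here only so the
// sketch compiles on its own. In a real library, include lib_api.h instead.
enum MXReturnValue { MX_FAIL = 0, MX_SUCCESS = 1 };

// Hypothetical minimum supported version (an assumption for illustration).
const int kMinSupportedVersion = 10700;

// Called once when the library is loaded; `version` is passed in by MXNet.
MXReturnValue initialize(int version) {
  if (version >= kMinSupportedVersion) {
    std::cout << "MXNet version " << version << " supported" << std::endl;
    return MX_SUCCESS;
  }
  std::cerr << "MXNet version " << version << " not supported" << std::endl;
  return MX_FAIL;
}
```

Returning a failure value here is how the library signals that it should not be used with the running MXNet version.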

Let's take a closer look at the graphPass function:

  • graphPass: This function takes two arguments. The first argument is the Graph of the model architecture, where nodes are operators or inputs/params/weights and edges are data dependencies. The second argument is the map of options specified by the user. Any custom options that users pass to the pass arrive in this function through the options map.
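
Because every option arrives as a string, a pass usually needs a little code to look up keys and supply defaults. The sketch below shows one possible way to do that with the same map type that graphPass receives. The helper names getOption and getFlag are hypothetical, and treating "1", "true", and "True" as true is an assumption of this sketch rather than documented MXNet behavior.

```cpp
#include <string>
#include <unordered_map>

// Same type as the second parameter of graphPass.
using Options = std::unordered_map<std::string, std::string>;

// Hypothetical helper: return the value for `key`, or `fallback` if the user
// did not pass that option.
std::string getOption(const Options& options, const std::string& key,
                      const std::string& fallback) {
  auto it = options.find(key);
  return it == options.end() ? fallback : it->second;
}

// Hypothetical helper: interpret an option as a boolean flag. Which spellings
// count as true is an assumption of this sketch.
bool getFlag(const Options& options, const std::string& key, bool fallback) {
  auto it = options.find(key);
  if (it == options.end()) return fallback;
  return it->second == "1" || it->second == "true" || it->second == "True";
}
```

Inside graphPass, one might then write getFlag(options, "verbose", false), where "verbose" is a made-up option name that a user would supply as an extra keyword argument to optimize_for.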

Graph representation

The Graph class represents the model's architecture. Each Node in the graph represents an operator or weight (i.e. args/aux param). Since an operator in MXNet can take multiple inputs and produce multiple outputs, each input/output is represented by a NodeEntry. A Node contains the following:

  • op - [string] operator name
  • name - [string] unique node name
  • inputs - [vector of NodeEntry] set of inputs to the node
  • outputs - [vector of NodeEntry] set of outputs from the node
  • subgraph - [vector of Graph] set of subgraphs in the node
  • attrs - [map of string to string] set of attributes for the node

The inputs are a set of NodeEntry where each contains a pointer to a Node that produces the data, and an entry that is the index of the output on the other Node. Conversely, the outputs are a set of NodeEntry where each contains a pointer to a Node that consumes the data, and an entry that is the index of the input on the other Node. This bidirectional dependency lets you traverse the graph easily in either direction.

A Graph contains the following:

  • nodes - [vector of Node] set of nodes in the graph
  • inputs - [vector of Node] set of inputs to the graph
  • outputs - [vector of NodeEntry] set of outputs from the graph
  • attrs - [map of string to JSON object] set of attributes for the graph

The nodes are all the nodes in the graph (superset). The inputs are only those nodes that are model inputs (e.g. input image) or weights (i.e. arg/aux params). The outputs are the outputs from the operators in the model that are true outputs of the model (e.g. prediction results).
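
To make the traversal concrete, the sketch below walks the nodes of a graph and counts how often each operator name appears. It uses stand-in Node and Graph structs that mirror only a few of the fields listed above so that it compiles without lib_api.h; the real classes have more members, and holding the nodes by pointer is an assumption of this sketch. The function name countOps is hypothetical.

```cpp
#include <map>
#include <string>
#include <vector>

// Stand-in types that mirror only some of the fields listed above; the real
// Node and Graph classes come from lib_api.h.
struct Node {
  std::string op;    // operator name
  std::string name;  // unique node name
};

struct Graph {
  std::vector<Node*> nodes;   // all nodes in the graph
  std::vector<Node*> inputs;  // model inputs and weights
};

// Count how many nodes use each operator name.
std::map<std::string, int> countOps(const Graph& g) {
  std::map<std::string, int> counts;
  for (const Node* n : g.nodes) {
    ++counts[n->op];
  }
  return counts;
}
```

A pass that only inspects the model, like the myPass example that prints the graph, can be built from a loop of this shape.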

Here's an example creating a new node and adding it to the graph:

g->addNode("myConv","Convolution");

Here's an example creating an edge between two nodes:

n1->outputs.push_back({n2,1});
n2->inputs.push_back({n1,0});

Here node n1 produces an output at index 0 that is consumed by node n2 on the input at index 1.
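
Since each edge has to be recorded on both nodes, it can help to wrap the two push_back calls in one helper so the indices cannot get out of sync. The sketch below does this with stand-in Node and NodeEntry structs that mirror the fields described earlier, so it compiles without lib_api.h. The helpers connect and consumerNames are hypothetical names, not part of the MXNet API.

```cpp
#include <cassert>
#include <string>
#include <vector>

struct Node;  // forward declaration so NodeEntry can point at a Node

// Stand-in types mirroring the fields described above (assumption: the real
// definitions in lib_api.h carry more members).
struct NodeEntry {
  Node* node;  // the node on the other end of the edge
  int entry;   // index of the output/input on that other node
};

struct Node {
  std::string name;
  std::vector<NodeEntry> inputs;
  std::vector<NodeEntry> outputs;
};

// Hypothetical helper: connect output `outIdx` of `producer` to input `inIdx`
// of `consumer`, recording both directions of the edge.
void connect(Node* producer, int outIdx, Node* consumer, int inIdx) {
  assert(producer != nullptr && consumer != nullptr);
  producer->outputs.push_back({consumer, inIdx});
  consumer->inputs.push_back({producer, outIdx});
}

// Collect the names of every node that consumes an output of `n`.
std::vector<std::string> consumerNames(const Node& n) {
  std::vector<std::string> names;
  for (const NodeEntry& e : n.outputs) {
    names.push_back(e.node->name);
  }
  return names;
}
```

With these helpers, the edge from the example above would be written as connect(n1, 0, n2, 1).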

[Figure: example connection (example_connection.png)]

Some graph passes require allocating new NDArrays to add/replace model params. The alloc_arg and alloc_aux APIs enable allocating new NDArrays and integrating them with the model args and aux params. Both APIs have the following signature:

    MXTensor* alloc_xxx(const std::vector<int64_t>& shapes,
                        const MXContext &ctx,
                        MXDType dtype)

This function can be called on a node in the graph to allocate a tensor for that node like:

node->alloc_arg({1},MXContext::CPU(0),kFloat32);

It adds a new param to the appropriate arg/aux set when the graph pass returns. If you wish to remove an existing param, just remove the node in the graph corresponding to that param. It will be deleted after the pass completes and removed from the dictionary of args or aux (whichever it is a member of).

Parsing a JSON string

To simplify custom libraries, basic JSON parsing utility functions have been implemented in the lib_api.h header file. You parse a JSON string into a JsonVal object by calling the static JsonVal::parse function like:

JsonVal json_val = JsonVal::parse(json);

A JsonVal is a class that represents the nodes in a JSON structure. You can check the type of a node (num, str, list, or map) by comparing the JsonVal.type to STR, NUM, LIST, or MAP. Then you can get that value from the node like:

switch(json_val.type) {
  // Each case has its own braces because it declares a local variable
  case STR: {
    std::string str = json_val.str;
    break;
  }
  case NUM: {
    int num = json_val.num;
    break;
  }
  case LIST: {
    std::vector<JsonVal> list = json_val.list;
    break;
  }
  case MAP: {
    std::map<JsonVal, JsonVal> map = json_val.map;
    break;
  }
  default:
    // error
    break;
}

You call the dump function on a JsonVal object like json_val.dump() to get a JSON-compatible string. There are also convenience constructors for creating JsonVal objects for strings and numbers like JsonVal("myKey") or JsonVal(42). This makes it easy to get specific keys from a map like json_val.map[JsonVal("nodes")].