add onnx,autograd doc
diff --git a/docs-site/docs/autograd.md b/docs-site/docs/autograd.md
index 7917c94..517142a 100644
--- a/docs-site/docs/autograd.md
+++ b/docs-site/docs/autograd.md
@@ -8,7 +8,7 @@
 There are two typical ways to implement autograd, via symbolic differentiation
 like [Theano](http://deeplearning.net/software/theano/index.html) or reverse
 differentiation like
-[Pytorch](https://pytorch.org/docs/stable/notes/autograd.html). Singa follows
+[Pytorch](https://pytorch.org/docs/stable/notes/autograd.html). SINGA follows
 Pytorch way, which records the computation graph and apply the backward
 propagation automatically after forward propagation. The autograd algorithm is
 explained in details
@@ -18,9 +18,9 @@
 ## Relevant Modules
 
 There are three classes involved in autograd, namely `singa.tensor.Tensor`,
-`singa.autograd.Operation`, `singa.autograd.Layer` and `singa.module.Module`. In
-the rest of this article, we use tensor, operation, layer and module to refer to
-an instance of the respective class.
+`singa.autograd.Operation`, and `singa.autograd.Layer`. In the rest of this
+article, we use tensor, operation and layer to refer to an instance of the
+respective class.
 
 ### Tensor
 
@@ -68,13 +68,6 @@
 `Layer` manages (stores) the parameters and calls the corresponding `Operation`s
 to implement the transformation.
 
-### Module
-
-For every neural network, it can be a subclass of Module. It is used to buffer
-all the operations in the neural network and form a computational graph. SINGA
-will schedule the operations and memory allocation to make training more
-efficient while using less memory.
-
 ## Examples
 
 Multiple examples are provided in the
@@ -265,3 +258,9 @@
     if i % (niters / 10) == 0 and rank_in_global == 0:
         print("training loss = ", tensor.to_numpy(loss)[0], flush=True)
 ```
+
+### Python API
+
+Refer to the
+[autograd API reference](https://singa.readthedocs.io/en/latest/docs/autograd.html#module-singa.autograd)
+for more details of the Python API.
diff --git a/docs-site/docs/onnx.md b/docs-site/docs/onnx.md
index 088445f..47b5d67 100644
--- a/docs-site/docs/onnx.md
+++ b/docs-site/docs/onnx.md
@@ -7,20 +7,163 @@
 
 ONNX is an open format built to represent machine learning models, which enables
 an ability to transfer trained models between different deep learning
-frameworks. We have integrated the main functionality of ONNX into Singa, and
+frameworks. We have integrated the main functionality of ONNX into SINGA, and
 several basic operators have been supported. More operators are being
 developing.
 
-## Example: ONNX mnist on singa
+The [ONNX
+version](https://github.com/onnx/onnx/blob/master/docs/Versioning.md) supported
+by SINGA is:
 
-We will introduce the onnx of singa by using the mnist example. In this section,
-the examples of how to export, load, inference, re-training, and
-transfer-learning the minist model will be displayed.
+| ONNX version | File format version | Opset version ai.onnx | Opset version ai.onnx.ml | Opset version ai.onnx.training |
+| ------------ | ------------------- | --------------------- | ------------------------ | ------------------------------ |
+| 1.6.0        | 6                   | 11                    | 2                        | -                              |
+
+## General usage
+
+The ONNX module of SINGA supports the basic functionality. Please refer to the
+following tutorials for general usage:
+
+### Loading an ONNX Model into SINGA
+
+This part shows how to import an ONNX model and prepare a SINGA model from it.
+After you load an ONNX model with `onnx.load`, you need to update the model's
+batch size, since most models use a placeholder to represent it. The
+`update_batch_size` function below is an example. You only need to update the
+batch size of the input and output; the shapes of the inner tensors are
+inferred automatically.
+
+Then, you can prepare the SINGA model by using `sonnx.prepare`. This function
+iterates over all the nodes in the ONNX model's graph, translates them to SINGA
+operators, loads all stored weights and infers each intermediate tensor's shape.
+For the device used, please refer to the `device` section.
+
+```python3
+import onnx
+from singa import device
+from singa import sonnx
+
+def update_batch_size(onnx_model, batch_size):
+    model_input = onnx_model.graph.input[0]
+    model_input.type.tensor_type.shape.dim[0].dim_value = batch_size
+    model_output = onnx_model.graph.output[0]
+    model_output.type.tensor_type.shape.dim[0].dim_value = batch_size
+    return onnx_model
+
+
+model_path = "PATH/To/ONNX/MODEL"
+onnx_model = onnx.load(model_path)
+
+# set batch size
+onnx_model = update_batch_size(onnx_model, 1)
+
+# prepare the model
+dev = device.create_cuda_gpu()
+sg_ir = sonnx.prepare(onnx_model, device=dev)
+```
+
+### Inference with a SINGA model
+
+After you load and prepare a SINGA model, you can run inference by calling
+`sg_ir.run` as in the following code. The input and output must be SINGA
+`Tensor`s. Since a SINGA model returns its output as a list, if the model has
+only one output, just take the first element of the list, as in the `forward`
+method of the `Infer` class.
+
+```python3
+from singa import tensor
+
+class Infer:
+    def __init__(self, sg_ir):
+        self.sg_ir = sg_ir
+
+    def forward(self, x):
+        return self.sg_ir.run([x])[0]
+
+
+data = get_dataset()  # placeholder for your own data loading (numpy array)
+x = tensor.Tensor(device=dev, data=data)
+
+model = Infer(sg_ir)
+y = model.forward(x)
+```
+
+### Saving an ONNX Model from SINGA
+
+Now, if you have a SINGA model, you can export it as an ONNX model as follows:
+
+```python3
+sonnx.to_onnx([x], [y])
+```
+
+### Re-training an ONNX model
+
+You can also re-train an ONNX model after loading it into SINGA, as in the
+following code. Note that every tensor of the SINGA model must be allowed to
+store its gradient: set `tens.requires_grad = True` and `tens.stores_grad = True`.
+
+```python3
+class Infer:
+
+    def __init__(self, sg_ir):
+        self.sg_ir = sg_ir
+        for idx, tens in sg_ir.tensor_map.items():
+            # allow the tensors to be updated
+            tens.requires_grad = True
+            tens.stores_grad = True
+
+    def forward(self, x):
+        return self.sg_ir.run([x])[0]
+
+autograd.training = False
+model = Infer(sg_ir)
+
+# then train the model as usual
+```
+
+### Transfer learning with an ONNX model
+
+You can also append some layers to the end of an ONNX model to do transfer
+learning, as follows. `last_layers` means that only the ONNX layers in [0,
+last_layers] are run. You can then append more layers as in a normal SINGA model.
+
+```python3
+class Trans:
+
+    def __init__(self, sg_ir, last_layers):
+        self.sg_ir = sg_ir
+        self.last_layers = last_layers
+        self.append_linear1 = autograd.Linear(500, 128, bias=False)
+        self.append_linear2 = autograd.Linear(128, 32, bias=False)
+        self.append_linear3 = autograd.Linear(32, 10, bias=False)
+
+    def forward(self, x):
+        y = self.sg_ir.run([x], last_layers=self.last_layers)[0]
+        y = self.append_linear1(y)
+        y = autograd.relu(y)
+        y = self.append_linear2(y)
+        y = autograd.relu(y)
+        y = self.append_linear3(y)
+        y = autograd.relu(y)
+        return y
+
+autograd.training = False
+model = Trans(sg_ir, -1)
+
+# then train the model as usual
+```
+
+## Example: ONNX mnist on SINGA
+
+This part introduces the usage of SINGA ONNX through the MNIST example. It
+shows how to export the MNIST model, load it, run inference with it, re-train
+it, and apply transfer learning to it. You can try this part in
+[this Colab notebook](https://colab.research.google.com/drive/1-YOfQqqw3HNhS8WpB8xjDQYutRdUdmCq).
 
 ### Load dataset
 
-Firstly, we import some necessary libraries and define some auxiliary functions
-for downloading and preprocessing the dataset:
+Firstly, you need to import some necessary libraries and define some auxiliary
+functions for downloading and preprocessing the dataset:
 
 ```python
 import os
@@ -102,10 +245,11 @@
 
 ### MNIST model
 
-We define a class called **CNN** to construct the mnist model which consists of
-several convolution, pooling, fully connection and relu layers. We also define a
-function to calculate the **accuracy** of our result. Finally, we define a
-**train** and a **test** function to handle the training and prediction process.
+Then you can define a class called **CNN** to construct the MNIST model, which
+consists of several convolution, pooling, fully connected and ReLU layers. You
+can also define a function to calculate the **accuracy** of the result. Finally,
+you can define a **train** and a **test** function to handle the training and
+prediction process.
 
 ```python
 class CNN:
@@ -193,7 +337,7 @@
 
 ### Train mnist model and export it to onnx
 
-Now, we can train the mnist model and export its onnx model by calling the
+Now, you can train the mnist model and export its onnx model by calling the
 **soonx.to_onnx** function.
 
 ```python
@@ -226,19 +370,19 @@
 
 ### Inference
 
-After we export the onnx model, we can find a file called **mnist.onnx** in the
-'/tmp' directory, this model, therefore, can be imported by other libraries.
-Now, if we want to import this onnx model into singa again and do the inference
-using the validation dataset, we can define a class called **Infer**, the
+After you export the ONNX model, you can find a file called **mnist.onnx** in
+the '/tmp' directory; this model can therefore be imported by other libraries.
+Now, if you want to import this ONNX model into SINGA again and do the inference
+using the validation dataset, you can define a class called **Infer**; the
 forward function of Infer will be called by the test function to do inference
-for validation dataset. By the way, we should set the label of training to
+for validation dataset. Note that you should set `autograd.training` to
 **False** to fix the gradient of autograd operators.
 
-When import the onnx model, we firstly call **onnx.load** to load the onnx
-model. Then the onnx model will be fed into the **soonx.prepare** to parse and
-initiate to a singa model(**sg_ir** in the code). The sg_ir contains a singa
-graph within it, and we can run an step of inference by feeding input to its run
-function.
+When importing the ONNX model, you first need to call **onnx.load** to load it.
+Then the ONNX model is fed into **sonnx.prepare**, which parses it and
+initializes a SINGA model (**sg_ir** in the code). The sg_ir contains a SINGA
+graph within it, and you can then run a step of inference by feeding input to
+its run function.
 
 ```python
 class Infer:
@@ -265,7 +409,7 @@
 
 ### Re-training
 
-Assume after import the model, we want to re-train the model again, we can
+Assume that after importing the model, you want to re-train it; you can
 define a function called **re_train**. Before we call this re_train function, we
 should set the label of training to **True** to make the autograde operators
 update their gradient. And after we finish the training, we set it as **False**
@@ -321,11 +465,11 @@
 ### Transfer learning
 
 Finally, if we want to do transfer-learning, we can define a function called
-**Trans** to append some layers after the onnx model. For demonstration, we only
-append several linear(fully connection) and relu after the onnx model. We also
-define a transfer_learning function to handle the training process of the
-transfer-learning model. And the label of training is the same as the previous
-one.
+**Trans** to append some layers after the ONNX model. For demonstration, the
+code only appends several linear (fully connected) and ReLU layers after the
+ONNX model. You can define a transfer_learning function to handle the training
+process of the transfer-learning model. The training flag is set in the same
+way as in the previous section.
 
 ```python
 class Trans:
@@ -391,178 +535,227 @@
 test(new_model, valid_x, valid_y, dev=dev)
 ```
 
-## Example: ONNX tiny_yolov2 on singa
+## ONNX model zoo
 
-Now, the onnx of Singa supports importing models from
-[Onnx Model Zoo](https://github.com/onnx/models). We will show you how to
-inmport a Tiny-Yolo-V2 model and verify the correctness of the model by using
-its test dataset.
+The [ONNX Model Zoo](https://github.com/onnx/models) is a collection of
+pre-trained, state-of-the-art models in the ONNX format contributed by community
+members. SINGA now supports several CV and NLP models, and more models will be
+supported soon.
 
-### Load model
+### Image Classification
 
-Firstly, we try to download the Tiny-Yolo-V2 model from the Onnx Model Zoo if it
-doesn't exist already, and then load this model:
+This collection of models takes images as input, then classifies the major
+objects in the images into 1000 object categories such as keyboard, mouse,
+pencil, and many animals.
 
-```python
-def load_model():
-    url = 'https://onnxzoo.blob.core.windows.net/models/opset_8/tiny_yolov2/tiny_yolov2.tar.gz'
-    download_dir = '/tmp/'
-    filename = os.path.join(download_dir, 'tiny_yolov2', '.', 'Model.onnx')
-    with tarfile.open(check_exist_or_download(url), 'r') as t:
-        t.extractall(path=download_dir)
-    return filename
+| Model Class                                                                                    | Reference                                          | Description                                                                                                                                                                              | Link                                                                                                                                                    |
+| ---------------------------------------------------------------------------------------------- | -------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| <b>[MobileNet](https://github.com/onnx/models/tree/master/vision/classification/mobilenet)</b> | [Sandler et al.](https://arxiv.org/abs/1801.04381) | Light-weight deep neural network best suited for mobile and embedded vision applications. <br>Top-5 error from paper - ~10%                                                              | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1HsixqJMIpKyEPhkbB8jy7NwNEFEAUWAf) |
+| <b>[ResNet18](https://github.com/onnx/models/tree/master/vision/classification/resnet)</b>     | [He et al.](https://arxiv.org/abs/1512.03385)      | A CNN model (up to 152 layers). Uses shortcut connections to achieve higher accuracy when classifying images. <br> Top-5 error from paper - ~3.6%                                        | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1u1RYefSsVbiP4I-5wiBKHjsT9L0FxLm9) |
+| <b>[VGG16](https://github.com/onnx/models/tree/master/vision/classification/vgg)</b>           | [Simonyan et al.](https://arxiv.org/abs/1409.1556) | Deep CNN model (up to 19 layers). Similar to AlexNet but uses multiple smaller kernel-sized filters that provide more accuracy when classifying images. <br>Top-5 error from paper - ~8% | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14kxgRKtbjPCKKsDJVNi3AvTev81Gp_Ds) |
 
-def check_exist_or_download(url):
-    download_dir = '/tmp/'
-    name = url.rsplit('/', 1)[-1]
-    filename = os.path.join(download_dir, name)
-    if not os.path.isfile(filename):
-        print("Downloading %s" % url)
-        urllib.request.urlretrieve(url, filename)
-    return filename
+### Object Detection
 
-dev = device.create_cuda_gpu()
-model_path = load_model()
-onnx_model = onnx.load(model_path)
-```
+Object detection models detect the presence of multiple objects in an image and
+segment out areas of the image where the objects are detected.
 
-### Set batchsize and prepare model
+| Model Class                                                                                                       | Reference                                             | Description                                                                                                                        | Link                                                                                                                                                    |
+| ----------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| <b>[Tiny YOLOv2](https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/tiny_yolov2)</b> | [Redmon et al.](https://arxiv.org/pdf/1612.08242.pdf) | A real-time CNN for object detection that detects 20 different classes. A smaller version of the more complex full YOLOv2 network. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/11V4I6cRjIJNUv5ZGsEGwqHuoQEie6b1T) |
 
-Then since lots of example models don't indicate its batch size, we need to
-update it. After that, we can parse the onnx model into singa model:
+### Face Analysis
 
-```python
-def update_batch_size(onnx_model, batch_size):
-    model_input = onnx_model.graph.input[0]
-    model_input.type.tensor_type.shape.dim[0].dim_value = batch_size
-    return onnx_model
+Face detection models identify and/or recognize human faces and emotions in
+given images.
 
-# set batch size
-onnx_model = update_batch_size(onnx_model, 1)
-sg_ir = sonnx.prepare(onnx_model, device=dev)
-```
+| Model Class                                                                                               | Reference                                          | Description                                                                                                                         | Link                                                                                                                                                    |
+| --------------------------------------------------------------------------------------------------------- | -------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| <b>[ArcFace](https://github.com/onnx/models/tree/master/vision/body_analysis/arcface)</b>                 | [Deng et al.](https://arxiv.org/abs/1801.07698)    | A CNN based model for face recognition which learns discriminative features of faces and produces embeddings for input face images. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qanaqUKGIDtifdzEzJOHjEj4kYzA9uJC) |
+| <b>[Emotion FerPlus](https://github.com/onnx/models/tree/master/vision/body_analysis/emotion_ferplus)</b> | [Barsoum et al.](https://arxiv.org/abs/1608.01041) | Deep CNN for emotion recognition trained on images of faces.                                                                        | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1XHtBQGRhe58PDi4LGYJzYueWBeWbO23r) |
 
-### Define inference
+### Machine Comprehension
 
-For clearness, we define a Infer functin to hold the model's forward process:
+This subset of natural language processing models answers questions about a
+given context paragraph.
 
-```python
-class Infer:
-    def __init__(self, sg_ir):
-        self.sg_ir = sg_ir
-        for idx, tens in sg_ir.tensor_map.items():
-            # allow the tensors to be updated
-            tens.requires_grad = True
-            tens.stores_grad = True
-            sg_ir.tensor_map[idx] = tens
-
-    def forward(self, x):
-        return sg_ir.run([x])[0]
-
-# inference
-autograd.training = False
-model = Infer(sg_ir)
-```
-
-### Load dataset, run and verify
-
-Finally, we load the test dataset which is provided by Onnx Model Zoo, do the
-inference and verify its correctness.
-
-```python
-def load_dataset(test_data_dir):
-    # Load inputs
-    inputs = []
-    inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
-    for i in range(inputs_num):
-        input_file = os.path.join(test_data_dir, 'input_{}.pb'.format(i))
-        tensor = onnx.TensorProto()
-        with open(input_file, 'rb') as f:
-            tensor.ParseFromString(f.read())
-        inputs.append(numpy_helper.to_array(tensor))
-
-    # Load reference outputs
-    ref_outputs = []
-    ref_outputs_num = len(glob.glob(os.path.join(test_data_dir, 'output_*.pb')))
-    for i in range(ref_outputs_num):
-        output_file = os.path.join(test_data_dir, 'output_{}.pb'.format(i))
-        tensor = onnx.TensorProto()
-        with open(output_file, 'rb') as f:
-            tensor.ParseFromString(f.read())
-        ref_outputs.append(numpy_helper.to_array(tensor))
-    return inputs, ref_outputs
-
-inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'tiny_yolov2', 'test_data_set_0'))
-x_batch = tensor.Tensor(device=dev, data=inputs[0])
-outputs = model.forward(x_batch)
-
-# Compare the results with reference outputs.
-for ref_o, o in zip(ref_outputs, outputs):
-    np.testing.assert_almost_equal(ref_o, o)
-```
+| Model Class                                                                                           | Reference                                             | Description                                                                     | Link                                                                                                                                                    |
+| ----------------------------------------------------------------------------------------------------- | ----------------------------------------------------- | ------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| <b>[BERT-Squad](https://github.com/onnx/models/tree/master/text/machine_comprehension/bert-squad)</b> | [Devlin et al.](https://arxiv.org/pdf/1810.04805.pdf) | This model answers questions based on the context of the given input paragraph. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kud-lUPjS_u-TkDAzihBTw0Vqr0FjCE-) |
 
 ## Supported operators
 
 The following operators are supported:
 
-| Operation          | Comments                                  |
-| ------------------ | ----------------------------------------- |
-| Conv               | not support SAME_UPPER and SAME_LOWER yet |
-| Relu               | -                                         |
-| Constant           | -                                         |
-| MaxPool            | -                                         |
-| AveragePool        | -                                         |
-| Softmax            | -                                         |
-| Sigmoid            | -                                         |
-| Add                | -                                         |
-| MatMul             | -                                         |
-| BatchNormalization | -                                         |
-| Concat             | -                                         |
-| Flatten            | -                                         |
-| Add                | -                                         |
-| Gemm               | -                                         |
-| Reshape            | -                                         |
-| Sum                | -                                         |
-| Cos                | -                                         |
-| Cosh               | -                                         |
-| Sin                | -                                         |
-| Sinh               | -                                         |
-| Tan                | -                                         |
-| Tanh               | -                                         |
-| Acos               | -                                         |
-| Acosh              | -                                         |
-| Asin               | -                                         |
-| Asinh              | -                                         |
-| Atan               | -                                         |
-| Atanh              | -                                         |
-| Selu               | -                                         |
-| Elu                | -                                         |
-| Equal              | -                                         |
-| Less               | -                                         |
-| Sign               | -                                         |
-| Div                | -                                         |
-| Sub                | -                                         |
-| Sqrt               | -                                         |
-| Log                | -                                         |
-| Greater            | -                                         |
-| HardSigmoid        | -                                         |
-| Identity           | -                                         |
-| Softplus           | -                                         |
-| Softsign           | -                                         |
-| Mean               | -                                         |
-| Pow                | -                                         |
-| Clip               | -                                         |
-| PRelu              | -                                         |
-| Mul                | -                                         |
-| Transpose          | -                                         |
-| Max                | -                                         |
-| Min                | -                                         |
-| Shape              | -                                         |
-| And                | -                                         |
-| Or                 | -                                         |
-| Xor                | -                                         |
-| Not                | -                                         |
-| Neg                | -                                         |
-| Reciprocal         | -                                         |
-| LeakyRelu          | -                                         |
-| GlobalAveragePool  | -                                         |
+- Conv
+- Relu
+- Constant
+- MaxPool
+- AveragePool
+- Softmax
+- Sigmoid
+- Add
+- MatMul
+- BatchNormalization
+- Concat
+- Flatten
+- Add
+- Gemm
+- Reshape
+- Sum
+- Cos
+- Cosh
+- Sin
+- Sinh
+- Tan
+- Tanh
+- Acos
+- Acosh
+- Asin
+- Asinh
+- Atan
+- Atanh
+- Selu
+- Elu
+- Equal
+- Less
+- Sign
+- Div
+- Sub
+- Sqrt
+- Log
+- Greater
+- HardSigmoid
+- Identity
+- Softplus
+- Softsign
+- Mean
+- Pow
+- Clip
+- PRelu
+- Mul
+- Transpose
+- Max
+- Min
+- Shape
+- And
+- Or
+- Xor
+- Not
+- Neg
+- Reciprocal
+- LeakyRelu
+- GlobalAveragePool
+- ConstantOfShape
+- Dropout
+- ReduceSum
+- ReduceMean
+- LeakyRelu
+- GlobalAveragePool
+- Squeeze
+- Unsqueeze
+- Slice
+- Ceil
+- Split
+- Gather
+- Tile
+- NonZero
+- Cast
+- OneHot
+
+### Special comments for ONNX backend
+
+- Conv, MaxPool and AveragePool
+
+  Input must be of 1d (`N*C*H`) or 2d (`N*C*H*W`) shape, and `dilation` must be 1.
+
+- BatchNormalization
+
+  `epsilon` is 1e-05 and cannot be changed.
+
+- Cast
+
+  Only float32 and int32 are supported; other types are cast to these two types.
+
+- Squeeze and Unsqueeze
+
+  If you encounter errors when you `Squeeze` or `Unsqueeze` between `Tensor` and
+  Scalar, please report to us.
+
+- Empty tensor: an empty tensor is illegal in SINGA.
+
+## Implementation
+
+The code of SINGA ONNX is located at `python/singa/sonnx.py`. There are three
+main classes: `SingaFrontend`, `SingaBackend` and `SingaRep`. `SingaFrontend`
+translates a SINGA model to an ONNX model; `SingaBackend` translates an ONNX
+model to a `SingaRep` object, which stores all SINGA operators and tensors (the
+tensor in this doc means SINGA `Tensor`); `SingaRep` can be run like a SINGA model.
+
+### SingaFrontend
+
+The entry function of `SingaFrontend` is `singa_to_onnx_model`, which is also
+called `to_onnx`. `singa_to_onnx_model` creates the ONNX model, and it also
+creates an ONNX graph by using `singa_to_onnx_graph`.
+
+`singa_to_onnx_graph` accepts the output of the model, and recursively walks
+the SINGA model's graph from the output to collect all operators into a queue.
+The input and intermediate tensors, i.e., trainable weights, of the SINGA model
+are picked up at the same time. The input is stored in `onnx_model.graph.input`;
+the output is stored in `onnx_model.graph.output`; and the trainable weights are
+stored in `onnx_model.graph.initializer`.
+
+Then the SINGA operators in the queue are translated to ONNX operators one by
+one. `_rename_operators` defines the operator name mapping between SINGA and
+ONNX. `_special_operators` defines which function is used to translate each
+operator.
+
+In addition, some operators in SINGA are defined differently from ONNX, that
+is, ONNX regards some attributes of SINGA operators as inputs, so
+`_unhandled_operators` defines which function handles each such operator.
+
+Since the bool type is regarded as int32 in SINGA, `_bool_operators` defines the
+operators to be changed to the bool type.
+
+### SingaBackend
+
+The entry function of `SingaBackend` is `prepare`, which checks the version of
+the ONNX model and then calls `_onnx_model_to_singa_net`.
+
+The purpose of `_onnx_model_to_singa_net` is to get SINGA tensors and operators.
+The tensors are stored in a dictionary keyed by their names in ONNX, and the
+operators are stored in a queue in the form of
+`namedtuple('SingaOps', ['name', 'op', 'handle', 'forward'])`. For each
+operator, `name` is its ONNX node name; `op` is the ONNX node; `forward` is the
+SINGA operator's forward function; `handle` is prepared for some special
+operators such as Conv and Pooling, which have a `handle` object.
+
+The first step of `_onnx_model_to_singa_net` is to call `_init_graph_parameter`
+to get all tensors within the model. For trainable weights, it can initialize
+SINGA `Tensor`s from `onnx_model.graph.initializer`. Please note that the
+weights may also be stored within the graph's input or an ONNX node called
+`Constant`; SINGA can also handle these.
+
+Though all weights are stored within the ONNX model, nothing is known about the
+model's input except its shape and type. So SINGA supports two ways to
+initialize the input: (1) generate a random tensor from its shape and type, or
+(2) allow the user to assign the input. The first way works fine for most
+models; however, for some models such as BERT, the indices of a matrix cannot
+be randomly generated, otherwise errors occur.
+
+Then, `_onnx_model_to_singa_net` iterates over all nodes within the ONNX graph
+to translate them to SINGA operators. Again, `_rename_operators` defines the
+operator name mapping between SINGA and ONNX, and `_special_operators` defines
+which function is used to translate each operator. `_run_node` runs the
+generated SINGA model on its input tensors and stores its output tensors for
+use by later operators.
+
+This class finally returns a `SingaRep` object and stores all SINGA tensors and
+operators within it.
+
+### SingaRep
+
+`SingaRep` stores all SINGA tensors and operators. `run` accepts the input of
+the model and runs the SINGA operators one by one following the operator queue.
+The user can use `last_layers` to run the model only up to the last few layers.
+Set `all_outputs` to `False` to get only the final output, or to `True` to also
+get all the intermediate outputs.