update from singa-doc PR#41
diff --git a/docs-site/website/versioned_docs/version-3.1.0/onnx.md b/docs-site/website/versioned_docs/version-3.1.0/onnx.md
index 5aae4b2..651a70e 100644
--- a/docs-site/website/versioned_docs/version-3.1.0/onnx.md
+++ b/docs-site/website/versioned_docs/version-3.1.0/onnx.md
@@ -23,66 +23,53 @@
 
 ### Loading an ONNX Model into SINGA
 
-After loading an ONNX model from disk by `onnx.load`, you need to update the
-model's batchsize, since for most models, they use a placeholder to represent
-its batchsize. We give an example here, as `update_batch_size`. You only need to
-update the batchsize of input and output, the shape of internal tensors will be
-inferred automatically.
+After loading an ONNX model from disk with `onnx.load`, you only need to update
+the batch size of the input using `tensor.PlaceHolder` (since SINGA v3.0); the
+shape of internal tensors will be inferred automatically.
 
-Then, you can prepare the SINGA model by using `sonnx.prepare`. This function
-iterates and translates all the nodes within the ONNX model's graph into SINGA
-operators, loads all stored weights and infers each intermediate tensor's shape.
+Then, you should define a class inheriting from `sonnx.SONNXModel` and implement
+two methods: `forward` for the forward pass and `train_one_batch` for training.
+After you call `model.compile`, SONNX iterates over and translates all the nodes
+within the ONNX model's graph into SINGA operators, loads all stored weights,
+and infers the shape of each intermediate tensor.
 
 ```python3
 import onnx
 from singa import device
 from singa import sonnx
+from singa import tensor
 
-# if the input has multiple tensors? can put this function inside prepare()?
-def update_batch_size(onnx_model, batch_size):
-    model_input = onnx_model.graph.input[0]
-    model_input.type.tensor_type.shape.dim[0].dim_value = batch_size
-    model_output = onnx_model.graph.output[0]
-    model_output.type.tensor_type.shape.dim[0].dim_value = batch_size
-    return onnx_model
+class MyModel(sonnx.SONNXModel):
 
+    def __init__(self, onnx_model):
+        super(MyModel, self).__init__(onnx_model)
+
+    def forward(self, *x):
+        y = super(MyModel, self).forward(*x)
+        # Since SINGA model returns the output as a list,
+        # if there is only one output,
+        # you just need to take the first element.
+        return y[0]
+
+    def train_one_batch(self, x, y):
+        pass
 
 model_path = "PATH/To/ONNX/MODEL"
 onnx_model = onnx.load(model_path)
 
-# set batch size
-onnx_model = update_batch_size(onnx_model, 1)
-
-# convert onnx graph nodes into SINGA operators
+# convert onnx model into SINGA model
 dev = device.create_cuda_gpu()
-sg_ir = sonnx.prepare(onnx_model, device=dev)
+# INPUT stands for your model's actual input data (e.g. a numpy array)
+x = tensor.PlaceHolder(INPUT.shape, device=dev)
+model = MyModel(onnx_model)
+model.compile([x], is_train=False, use_graph=True, sequential=True)
 ```
 
 ### Inference SINGA model
 
-Once the model is created, you can do inference by calling `sg_ir.run`. The
-input and output must be SINGA `Tensor` instances. Since SINGA model returns the
-output as a list, if there is only one output, you just need to take the first
-element from the output.
+Once the model is created, you can do inference by calling `model.forward`. The
+input and output must be SINGA `Tensor` instances.
 
 ```python3
-# can warp the following code in prepare()
-# and provide a flag training=True/False?
-
-class Infer:
-
-
-    def __init__(self, sg_ir):
-        self.sg_ir = sg_ir
-
-    def forward(self, x):
-        return sg_ir.run([x])[0]
-
-
-data = get_dataset()
-x = tensor.Tensor(device=dev, data=data)
-
-model = Infer(sg_ir)
+x = tensor.Tensor(device=dev, data=INPUT)
 y = model.forward(x)
 ```
 
@@ -100,443 +87,99 @@
 
 ### Re-training an ONNX model
 
-To train (or refine) an ONNX model using SINGA, you need to set the internal
-tensors to be trainable
+To train (or refine) an ONNX model using SINGA, you need to implement the
+`train_one_batch` method of `sonnx.SONNXModel` and set `is_train=True` when
+calling `model.compile`.
 
 ```python3
-class Infer:
+from singa import opt
+from singa import autograd
+from singa import sonnx
 
-    def __init__(self, sg_ir):
-        self.sg_ir = sg_ir
-        ## can wrap these codes in sonnx?
-        for idx, tens in sg_ir.tensor_map.items():
-            # allow the tensors to be updated
-            tens.requires_grad = True
-            tens.stores_grad = True
+class MyModel(sonnx.SONNXModel):
 
-    def forward(self, x):
-        return sg_ir.run([x])[0]
+    def __init__(self, onnx_model):
+        super(MyModel, self).__init__(onnx_model)
 
-autograd.training = False
-model = Infer(sg_ir)
+    def forward(self, *x):
+        y = super(MyModel, self).forward(*x)
+        return y[0]
 
-autograd.training = True
-# then you training the model like normal
-# give more details??
+    def train_one_batch(self, x, y, dist_option, spars):
+        out = self.forward(x)
+        loss = autograd.softmax_cross_entropy(out, y)
+        if dist_option == 'fp32':
+            self.optimizer.backward_and_update(loss)
+        elif dist_option == 'fp16':
+            self.optimizer.backward_and_update_half(loss)
+        elif dist_option == 'partialUpdate':
+            self.optimizer.backward_and_partial_update(loss)
+        elif dist_option == 'sparseTopK':
+            self.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == 'sparseThreshold':
+            self.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        return out, loss
+
+    def set_optimizer(self, optimizer):
+        self.optimizer = optimizer
+
+model = MyModel(onnx_model)
+sgd = opt.SGD(lr=0.005, momentum=0.9, weight_decay=1e-5)
+model.set_optimizer(sgd)
+# 'tx' is an input placeholder (e.g. tensor.PlaceHolder as shown above) and
+# 'graph' is a boolean flag that enables the computational graph
+model.compile([tx], is_train=True, use_graph=graph, sequential=True)
 ```
 
 ### Transfer-learning an ONNX model
 
-You also can append some layers to the end of ONNX model to do
-transfer-learning. The `last_layers` means you cut the ONNX layers from [0,
-last_layers]. Then you can append more layers by the normal SINGA model.
+You also can append some layers to the end of the ONNX model to do
+transfer-learning. The `last_layers` parameter accepts a negative integer
+indicating the layer to cut off from. For example, `-1` means cut off after the
+final output (i.e. do not cut off any layer), and `-2` means cut off after the
+second-to-last layer.
 
 ```python3
-class Trans:
-
-    def __init__(self, sg_ir, last_layers):
-        self.sg_ir = sg_ir
-        self.last_layers = last_layers
-        self.append_linear1 = autograd.Linear(500, 128, bias=False)
-        self.append_linear2 = autograd.Linear(128, 32, bias=False)
-        self.append_linear3 = autograd.Linear(32, 10, bias=False)
-
-    def forward(self, x):
-        y = sg_ir.run([x], last_layers=self.last_layers)[0]
-        y = self.append_linear1(y)
-        y = autograd.relu(y)
-        y = self.append_linear2(y)
-        y = autograd.relu(y)
-        y = self.append_linear3(y)
-        y = autograd.relu(y)
-        return y
-
-autograd.training = False
-model = Trans(sg_ir, -1)
-
-# then you training the model like normal
-```
-
-## A Full Example
-
-This part introduces the usage of SINGA ONNX by using the mnist example. In this
-section, the examples of how to export, load, inference, re-training, and
-transfer-learning the minist model are displayed. You can try this part
-[here](https://colab.research.google.com/drive/1-YOfQqqw3HNhS8WpB8xjDQYutRdUdmCq).
-
-### Load dataset
-
-Firstly, you need to import some necessary libraries and define some auxiliary
-functions for downloading and preprocessing the dataset:
-
-```python
-import os
-import urllib.request
-import gzip
-import numpy as np
-import codecs
-
-from singa import device
-from singa import tensor
 from singa import opt
 from singa import autograd
-from singa import sonnx
-import onnx
+from singa import sonnx
+from singa import layer
 
+class MyModel(sonnx.SONNXModel):
 
-def load_dataset():
-    train_x_url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
-    train_y_url = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
-    valid_x_url = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'
-    valid_y_url = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'
-    train_x = read_image_file(check_exist_or_download(train_x_url)).astype(
-        np.float32)
-    train_y = read_label_file(check_exist_or_download(train_y_url)).astype(
-        np.float32)
-    valid_x = read_image_file(check_exist_or_download(valid_x_url)).astype(
-        np.float32)
-    valid_y = read_label_file(check_exist_or_download(valid_y_url)).astype(
-        np.float32)
-    return train_x, train_y, valid_x, valid_y
+    def __init__(self, onnx_model):
+        super(MyModel, self).__init__(onnx_model)
+        self.linear = layer.Linear(1000, 3)
 
-
-def check_exist_or_download(url):
-
-    download_dir = '/tmp/'
-
-    name = url.rsplit('/', 1)[-1]
-    filename = os.path.join(download_dir, name)
-    if not os.path.isfile(filename):
-        print("Downloading %s" % url)
-        urllib.request.urlretrieve(url, filename)
-    return filename
-
-
-def read_label_file(path):
-    with gzip.open(path, 'rb') as f:
-        data = f.read()
-        assert get_int(data[:4]) == 2049
-        length = get_int(data[4:8])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=8).reshape(
-            (length))
-        return parsed
-
-
-def get_int(b):
-    return int(codecs.encode(b, 'hex'), 16)
-
-
-def read_image_file(path):
-    with gzip.open(path, 'rb') as f:
-        data = f.read()
-        assert get_int(data[:4]) == 2051
-        length = get_int(data[4:8])
-        num_rows = get_int(data[8:12])
-        num_cols = get_int(data[12:16])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(
-            (length, 1, num_rows, num_cols))
-        return parsed
-
-
-def to_categorical(y, num_classes):
-    y = np.array(y, dtype="int")
-    n = y.shape[0]
-    categorical = np.zeros((n, num_classes))
-    categorical[np.arange(n), y] = 1
-    categorical = categorical.astype(np.float32)
-    return categorical
-```
-
-### MNIST model
-
-Then you can define a class called **CNN** to construct the mnist model which
-consists of several convolution, pooling, fully connection and relu layers. You
-can also define a function to calculate the **accuracy** of our result. Finally,
-you can define a **train** and a **test** function to handle the training and
-prediction process.
-
-```python
-class CNN:
-    def __init__(self):
-        self.conv1 = autograd.Conv2d(1, 20, 5, padding=0)
-        self.conv2 = autograd.Conv2d(20, 50, 5, padding=0)
-        self.linear1 = autograd.Linear(4 * 4 * 50, 500, bias=False)
-        self.linear2 = autograd.Linear(500, 10, bias=False)
-        self.pooling1 = autograd.MaxPool2d(2, 2, padding=0)
-        self.pooling2 = autograd.MaxPool2d(2, 2, padding=0)
-
-    def forward(self, x):
-        y = self.conv1(x)
-        y = autograd.relu(y)
-        y = self.pooling1(y)
-        y = self.conv2(y)
-        y = autograd.relu(y)
-        y = self.pooling2(y)
-        y = autograd.flatten(y)
-        y = self.linear1(y)
-        y = autograd.relu(y)
-        y = self.linear2(y)
+    def forward(self, *x):
+        # cut off after the third-to-last layer
+        # and append a linear layer
+        y = super(MyModel, self).forward(*x, last_layers=-3)[0]
+        y = self.linear(y)
         return y
 
+    def train_one_batch(self, x, y, dist_option, spars):
+        out = self.forward(x)
+        loss = autograd.softmax_cross_entropy(out, y)
+        if dist_option == 'fp32':
+            self.optimizer.backward_and_update(loss)
+        elif dist_option == 'fp16':
+            self.optimizer.backward_and_update_half(loss)
+        elif dist_option == 'partialUpdate':
+            self.optimizer.backward_and_partial_update(loss)
+        elif dist_option == 'sparseTopK':
+            self.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == 'sparseThreshold':
+            self.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        return out, loss
 
-def accuracy(pred, target):
-    y = np.argmax(pred, axis=1)
-    t = np.argmax(target, axis=1)
-    a = y == t
-    return np.array(a, "int").sum() / float(len(t))
+    def set_optimizer(self, optimizer):
+        self.optimizer = optimizer
 
-
-def train(model,
-          x,
-          y,
-          epochs=1,
-          batch_size=64,
-          dev=device.get_default_device()):
-    batch_number = x.shape[0] // batch_size
-
-    for i in range(epochs):
-        for b in range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + 1) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-            output_batch = model.forward(x_batch)
-            # onnx_model = sonnx.to_onnx([x_batch], [y])
-            # print('The model is:\n{}'.format(onnx_model))
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=0.001)
-            for p, gp in autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            if b % 1e2 == 0:
-                print("acc %6.2f loss, %6.2f" %
-                      (accuracy_rate, tensor.to_numpy(loss)[0]))
-    print("training completed")
-    return x_batch, output_batch
-
-def test(model, x, y, batch_size=64, dev=device.get_default_device()):
-    batch_number = x.shape[0] // batch_size
-
-    result = 0
-    for b in range(batch_number):
-        l_idx = b * batch_size
-        r_idx = (b + 1) * batch_size
-
-        x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-        target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-        output_batch = model.forward(x_batch)
-        result += accuracy(tensor.to_numpy(output_batch),
-                           tensor.to_numpy(target_batch))
-
-    print("testing acc %6.2f" % (result / batch_number))
-```
-
-### Train mnist model and export it to onnx
-
-Now, you can train the mnist model and export its onnx model by calling the
-**soonx.to_onnx** function.
-
-```python
-def make_onnx(x, y):
-    return sonnx.to_onnx([x], [y])
-
-# create device
-dev = device.create_cuda_gpu()
-#dev = device.get_default_device()
-# create model
-model = CNN()
-# load data
-train_x, train_y, valid_x, valid_y = load_dataset()
-# normalization
-train_x = train_x / 255
-valid_x = valid_x / 255
-train_y = to_categorical(train_y, 10)
-valid_y = to_categorical(valid_y, 10)
-# do training
-autograd.training = True
-x, y = train(model, train_x, train_y, dev=dev)
-onnx_model = make_onnx(x, y)
-# print('The model is:\n{}'.format(onnx_model))
-
-# Save the ONNX model
-model_path = os.path.join('/', 'tmp', 'mnist.onnx')
-onnx.save(onnx_model, model_path)
-print('The model is saved.')
-```
-
-### Inference
-
-After you export the onnx model, you can find a file called **mnist.onnx** in
-the '/tmp' directory, this model, therefore, can be imported by other libraries.
-Now, if you want to import this onnx model into singa again and do the inference
-using the validation dataset, you can define a class called **Infer**, the
-forward function of Infer will be called by the test function to do inference
-for validation dataset. By the way, you should set the label of training to
-**False** to fix the gradient of autograd operators.
-
-When import the onnx model, you need to call **onnx.load** to load the onnx
-model firstly. Then the onnx model will be fed into the **soonx.prepare** to
-parse and initiate to a singa model(**sg_ir** in the code). The sg_ir contains a
-singa graph within it, and then you can run an step of inference by feeding
-input to its run function.
-
-```python
-class Infer:
-    def __init__(self, sg_ir):
-        self.sg_ir = sg_ir
-        for idx, tens in sg_ir.tensor_map.items():
-            # allow the tensors to be updated
-            tens.requires_grad = True
-            tens.stores_grad= True
-            sg_ir.tensor_map[idx] = tens
-
-    def forward(self, x):
-        return sg_ir.run([x])[0] # we can run one step of inference by feeding input
-
-# load the ONNX model
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev) # parse and initiate to a singa model
-
-# inference
-autograd.training = False
-print('The inference result is:')
-test(Infer(sg_ir), valid_x, valid_y, dev=dev)
-```
-
-### Re-training
-
-Assume after import the model, you want to re-train the model again, we can
-define a function called **re_train**. Before we call this re_train function, we
-should set the label of training to **True** to make the autograde operators
-update their gradient. And after we finish the training, we set it as **False**
-again to call the test function doing inference.
-
-```python
-def re_train(sg_ir,
-             x,
-             y,
-             epochs=1,
-             batch_size=64,
-             dev=device.get_default_device()):
-    batch_number = x.shape[0] // batch_size
-
-    new_model = Infer(sg_ir)
-
-    for i in range(epochs):
-        for b in range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + 1) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-            output_batch = new_model.forward(x_batch)
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=0.01)
-            for p, gp in autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            if b % 1e2 == 0:
-                print("acc %6.2f loss, %6.2f" %
-                      (accuracy_rate, tensor.to_numpy(loss)[0]))
-    print("re-training completed")
-    return new_model
-
-# load the ONNX model
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev)
-
-# re-training
-autograd.training = True
-new_model = re_train(sg_ir, train_x, train_y, dev=dev)
-autograd.training = False
-test(new_model, valid_x, valid_y, dev=dev)
-```
-
-### Transfer learning
-
-Finally, if we want to do transfer-learning, we can define a function called
-**Trans** to append some layers after the onnx model. For demonstration, the
-code only appends several linear(fully connection) and relu after the onnx
-model. You can define a transfer_learning function to handle the training
-process of the transfer-learning model. And the label of training is the same as
-the previous one.
-
-```python
-class Trans:
-    def __init__(self, sg_ir, last_layers):
-        self.sg_ir = sg_ir
-        self.last_layers = last_layers
-        self.append_linear1 = autograd.Linear(500, 128, bias=False)
-        self.append_linear2 = autograd.Linear(128, 32, bias=False)
-        self.append_linear3 = autograd.Linear(32, 10, bias=False)
-
-    def forward(self, x):
-        y = sg_ir.run([x], last_layers=self.last_layers)[0]
-        y = self.append_linear1(y)
-        y = autograd.relu(y)
-        y = self.append_linear2(y)
-        y = autograd.relu(y)
-        y = self.append_linear3(y)
-        y = autograd.relu(y)
-        return y
-
-def transfer_learning(sg_ir,
-             x,
-             y,
-             epochs=1,
-             batch_size=64,
-             dev=device.get_default_device()):
-    batch_number = x.shape[0] // batch_size
-
-    trans_model = Trans(sg_ir, -1)
-
-    for i in range(epochs):
-        for b in range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + 1) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-            output_batch = trans_model.forward(x_batch)
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=0.07)
-            for p, gp in autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            if b % 1e2 == 0:
-                print("acc %6.2f loss, %6.2f" %
-                      (accuracy_rate, tensor.to_numpy(loss)[0]))
-    print("transfer-learning completed")
-    return trans_mode
-
-# load the ONNX model
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev)
-
-# transfer-learning
-autograd.training = True
-new_model = transfer_learning(sg_ir, train_x, train_y, dev=dev)
-autograd.training = False
-test(new_model, valid_x, valid_y, dev=dev)
+model = MyModel(onnx_model)
+sgd = opt.SGD(lr=0.005, momentum=0.9, weight_decay=1e-5)
+model.set_optimizer(sgd)
+# 'tx' is an input placeholder (e.g. tensor.PlaceHolder as shown above) and
+# 'graph' is a boolean flag that enables the computational graph
+model.compile([tx], is_train=True, use_graph=graph, sequential=True)
 ```
 
 ## ONNX model zoo
@@ -559,6 +202,9 @@
 | <b>[VGG16](https://github.com/onnx/models/tree/master/vision/classification/vgg)</b>                | [Simonyan et al.](https://arxiv.org/abs/1409.1556)      | Deep CNN model(up to 19 layers). Similar to AlexNet but uses multiple smaller kernel-sized filters that provides more accuracy when classifying images. <br>Top-5 error from paper - ~8%                                                  | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14kxgRKtbjPCKKsDJVNi3AvTev81Gp_Ds) |
-| <b>[ShuffleNet_V2](https://github.com/onnx/models/tree/master/vision/classification/shufflenet)</b> | [Simonyan et al.](https://arxiv.org/pdf/1707.01083.pdf) | Extremely computation efficient CNN model that is designed specifically for mobile devices. This network architecture design considers direct metric such as speed, instead of indirect metric like FLOP. Top-1 error from paper - ~30.6% | [![Open In Colab](https://colab.research.google.com/drive/19HfRu3YHP_H2z3BcZujVFRp23_J5XsuA?usp=sharing)                                                |
+| <b>[ShuffleNet_V2](https://github.com/onnx/models/tree/master/vision/classification/shufflenet)</b> | [Simonyan et al.](https://arxiv.org/pdf/1707.01083.pdf) | Extremely computation efficient CNN model that is designed specifically for mobile devices. This network architecture design considers direct metric such as speed, instead of indirect metric like FLOP. Top-1 error from paper - ~30.6% | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/19HfRu3YHP_H2z3BcZujVFRp23_J5XsuA?usp=sharing)     |
 
+We also provide some re-training examples using VGG and ResNet; please check
+`examples/onnx/training`.
+
 ### Object Detection
 
 Object detection models detect the presence of multiple objects in an image and
@@ -694,11 +340,13 @@
 
 ## Implementation
 
-The code of SINGA ONNX locates at `python/singa/soonx.py`. There are three main
-class, `SingaFrontend` and `SingaBackend` and `SingaRep`. `SingaFrontend`
-translates a SINGA model to ONNX model; `SingaBackend` translates a ONNX model
-to `SingaRep` object which stores all SINGA operators and tensors(the tensor in
-this doc means SINGA `Tensor`); `SingaRep` can be run like a SINGA model.
+The code of SINGA ONNX is located at `python/singa/sonnx.py`. There are four
+main classes, `SingaFrontend`, `SingaBackend`, `SingaRep` and `SONNXModel`.
+`SingaFrontend` translates a SINGA model to an ONNX model; `SingaBackend`
+translates an ONNX model to a `SingaRep` object which stores all SINGA operators
+and tensors (the tensor in this doc means SINGA `Tensor`); `SingaRep` can be run
+like a SINGA model. `SONNXModel` inherits from `model.Model`, which defines a
+unified API for SINGA.
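+
+As a rough orientation, the sketch below maps each of these classes to an entry
+point described in this doc (`dev`, `x` and `y` are assumed to be a SINGA device
+and SINGA `Tensor`s you already have; exact keyword arguments may differ between
+SINGA versions):
+
+```python3
+from singa import sonnx
+
+# SingaFrontend: export a SINGA computation (inputs x, outputs y) to ONNX
+onnx_model = sonnx.to_onnx([x], [y])
+
+# SingaBackend: translate the ONNX model into a SingaRep
+sg_ir = sonnx.prepare(onnx_model, device=dev)
+
+# SingaRep: run it like a SINGA model
+outputs = sg_ir.run([x])
+
+# SONNXModel: the high-level wrapper shown in the examples above
+```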
 
 ### SingaFrontend
 
@@ -728,43 +376,35 @@
 ### SingaBackend
 
 The entry function of `SingaBackend` is `prepare` which checks the version of
-ONNX model and call `_onnx_model_to_singa_net` then.
+the ONNX model and then calls `_onnx_model_to_singa_ops`.
 
-The purpose of `_onnx_model_to_singa_net` is to get SINGA tensors and operators.
+The purpose of `_onnx_model_to_singa_ops` is to get SINGA tensors and operators.
 The tensors are stored in a dictionary by their name in ONNX, and operators are
-stored in queue by the form of
-`namedtuple('SingaOps', ['name', 'op', 'handle', 'forward'])`. For each
-operator, `name` is its ONNX node name; `op` is the ONNX node; `forward` is the
-SINGA operator's forward function; `handle` is prepared for some special
-operators such as Conv and Pooling which has `handle` object.
+stored in a queue in the form of `namedtuple('SingaOps', ['node', 'operator'])`.
+For each operator, `node` is an instance of `OnnxNode`, which is defined to store
+some basic information of an ONNX node; `operator` is the SINGA operator's
+forward function.
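+
+To make the queue entry concrete, here is a minimal runnable sketch; the node
+names and identity operators are placeholders only, since the real entries are
+built inside `sonnx.py` from `OnnxNode` instances and translated SINGA
+functions:
+
+```python3
+from collections import namedtuple
+
+SingaOps = namedtuple('SingaOps', ['node', 'operator'])
+
+# two hypothetical entries: 'node' would be an OnnxNode instance and
+# 'operator' the SINGA forward function translated from it; plain strings
+# and identity functions are used here only to keep the sketch runnable
+queue = [
+    SingaOps(node="Conv_0", operator=lambda t: t),
+    SingaOps(node="Relu_1", operator=lambda t: t),
+]
+
+# the queue is consumed entry by entry (see SingaRep below)
+t = 1.0
+for entry in queue:
+    t = entry.operator(t)
+```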
 
-The first step of `_onnx_model_to_singa_net` is to call `_init_graph_parameter`
-to get all tensors within the model. For trainable weights, it can init SINGA
-`Tensor` from `onnx_model.graph.initializer`. Please note, the weights may also
-be stored within graph's input or a ONNX node called `Constant`, SINGA can also
-handle these.
+`_onnx_model_to_singa_ops` works in four steps. First, it calls
+`_parse_graph_params` to get all tensors, stored as `params`. Then it calls
+`_parse_graph_inputs_outputs` to get all input and output information, stored as
+`inputs` and `outputs`. Next, it iterates over all nodes within the ONNX graph
+and parses each of them by `_onnx_node_to_singa_op` into SINGA operators or
+layers, stored as `outputs`. Finally, some weights are stored within an ONNX
+node called `Constant`; SONNX handles them by `_onnx_constant_to_np`, which
+converts them and stores them into `params`.
 
-Though all weights are stored within ONNX model, the input of the model is
-unknown but its shape and type. So SINGA support two ways to init input, 1,
-generate random tensor by its shape and type, 2, allow the user to assign the
-input. The first way works fine for most models, however, for some model such as
-bert, the indices of matrix cannot be random generated otherwise it will incurs
-errors.
-
-Then, `_onnx_model_to_singa_net` iterators all nodes within ONNX graph to
-translate it to SIGNA operators. Also, `_rename_operators` defines the operators
-name mapping between SINGA and ONNX. `_special_operators` defines which function
-to be used to translate the operator. `_run_node` runs the generated SINGA model
-by its input tensors and store its output tensors for being used by later
-operators.
-
-This class finally return a `SingaRep` object and stores all SINGA tensors and
-operators within it.
+This class finally returns a `SingaRep` object which stores the above `params`,
+`inputs`, `outputs` and `layers`.
 
 ### SingaRep
 
-`SingaBackend` stores all SINGA tensors and operators. `run` accepts the input
+`SingaRep` stores all SINGA tensors and operators. `run` accepts the input
-of the model and run the SINGA operators one by one following the operators
-queue. The user can use `last_layers` to decide to run the model till the last
-few layers. Set `all_outputs` as `False` to get only the final output, `True` to
-also get all the intermediate output.
+of the model and runs the SINGA operators one by one following the operators'
+queue. The user can use `last_layers` to cut off the model after the last few
+layers.
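+
+For reference, direct use of `SingaRep` looks roughly like the sketch below
+(assuming `sg_ir` is the `SingaRep` returned by `SingaBackend.prepare` and `x`
+is a SINGA `Tensor` on the same device):
+
+```python3
+# run the whole model
+y = sg_ir.run([x])[0]
+
+# run the model but cut it off after the second-to-last layer
+y_partial = sg_ir.run([x], last_layers=-2)[0]
+```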
+
+### SONNXModel
+
+`SONNXModel` inherits from `model.Model` and implements the method
+`forward` to provide a unified API with other SINGA models.