update from singa-doc PR#41
diff --git a/content/docs/next/onnx.html b/content/docs/next/onnx.html
index 62b63bb..425ce07 100644
--- a/content/docs/next/onnx.html
+++ b/content/docs/next/onnx.html
@@ -78,60 +78,47 @@
 </table>
 <h2><a class="anchor" aria-hidden="true" id="general-usage"></a><a href="#general-usage" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>General usage</h2>
 <h3><a class="anchor" aria-hidden="true" id="loading-an-onnx-model-into-singa"></a><a href="#loading-an-onnx-model-into-singa" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Loading an ONNX Model into SINGA</h3>
-<p>After loading an ONNX model from disk by <code>onnx.load</code>, you need to update the
-model's batchsize, since for most models, they use a placeholder to represent
-its batchsize. We give an example here, as <code>update_batch_size</code>. You only need to
-update the batchsize of input and output, the shape of internal tensors will be
-inferred automatically.</p>
-<p>Then, you can prepare the SINGA model by using <code>sonnx.prepare</code>. This function
-iterates and translates all the nodes within the ONNX model's graph into SINGA
-operators, loads all stored weights and infers each intermediate tensor's shape.</p>
-<pre><code class="hljs css language-python3"><span class="hljs-built_in">import</span> onnx
-from singa <span class="hljs-built_in">import</span> device
-from singa <span class="hljs-built_in">import</span> sonnx
+<p>After loading an ONNX model from disk with <code>onnx.load</code>, you only need to set
+the batch size of the input using <code>tensor.PlaceHolder</code> (since SINGA v3.0); the
+shapes of the internal tensors are inferred automatically.</p>
+<p>Then define a class that inherits from <code>sonnx.SONNXModel</code> and implement
+two methods: <code>forward</code> for the forward pass and <code>train_one_batch</code> for training.
+When you call <code>model.compile</code>, SONNX iterates over all the nodes in the ONNX
+model's graph, translates them into SINGA operators, loads all stored weights and
+infers each intermediate tensor's shape.</p>
+<pre><code class="hljs css language-python3"><span class="hljs-keyword">import</span> onnx
+<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> device, tensor
+<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> sonnx
 
-<span class="hljs-comment"># if the input has multiple tensors? can put this function inside prepare()?</span>
-def update_batch_size(onnx_model, batch_size):
-    <span class="hljs-attr">model_input</span> = onnx_model.graph.input[<span class="hljs-number">0</span>]
-    model_input.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size
-    <span class="hljs-attr">model_output</span> = onnx_model.graph.output[<span class="hljs-number">0</span>]
-    model_output.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size
-    return onnx_model
+<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyModel</span><span class="hljs-params">(sonnx.SONNXModel)</span>:</span>
 
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, onnx_model)</span>:</span>
+        super(MyModel, self).__init__(onnx_model)
 
-<span class="hljs-attr">model_path</span> = <span class="hljs-string">"PATH/To/ONNX/MODEL"</span>
-<span class="hljs-attr">onnx_model</span> = onnx.load(model_path)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, *x)</span>:</span>
+        y = super(MyModel, self).forward(*x)
+        <span class="hljs-comment"># Since SINGA model returns the output as a list,</span>
+        <span class="hljs-comment"># if there is only one output,</span>
+        <span class="hljs-comment"># you just need to take the first element.</span>
+        <span class="hljs-keyword">return</span> y[<span class="hljs-number">0</span>]
 
-<span class="hljs-comment"># set batch size</span>
-<span class="hljs-attr">onnx_model</span> = update_batch_size(onnx_model, <span class="hljs-number">1</span>)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(self, x, y)</span>:</span>
+        <span class="hljs-keyword">pass</span>
 
-<span class="hljs-comment"># convert onnx graph nodes into SINGA operators</span>
-<span class="hljs-attr">dev</span> = device.create_cuda_gpu()
-<span class="hljs-attr">sg_ir</span> = sonnx.prepare(onnx_model, <span class="hljs-attr">device=dev)</span>
+model_path = <span class="hljs-string">"PATH/To/ONNX/MODEL"</span>
+onnx_model = onnx.load(model_path)
+
+<span class="hljs-comment"># convert onnx model into SINGA model</span>
+dev = device.create_cuda_gpu()
+x = tensor.PlaceHolder(INPUT.shape, device=dev)
+model = MyModel(onnx_model)
+model.compile([x], is_train=<span class="hljs-literal">False</span>, use_graph=<span class="hljs-literal">True</span>, sequential=<span class="hljs-literal">True</span>)
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="inference-singa-model"></a><a href="#inference-singa-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference SINGA model</h3>
-<p>Once the model is created, you can do inference by calling <code>sg_ir.run</code>. The
-input and output must be SINGA <code>Tensor</code> instances. Since SINGA model returns the
-output as a list, if there is only one output, you just need to take the first
-element from the output.</p>
-<pre><code class="hljs css language-python3"><span class="hljs-comment"># can warp the following code in prepare()</span>
-<span class="hljs-comment"># and provide a flag training=True/False?</span>
-
-<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
-
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, sg_ir)</span></span>:
-        <span class="hljs-keyword">self</span>.sg_ir = sg_ir
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x)</span></span>:
-        <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>]
-
-
-data = get_dataset()
-x = tensor.Tensor(device=dev, data=data)
-
-model = Infer(sg_ir)
-y = model.forward(x)
+<p>Once the model is created, you can do inference by calling <code>model.forward</code>. The
+input and output must be SINGA <code>Tensor</code> instances.</p>
+<pre><code class="hljs css language-python3"><span class="hljs-attr">x</span> = tensor.Tensor(device=dev, data=INPUT)
+<span class="hljs-attr">y</span> = model.forward(x)
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="saving-singa-model-into-onnx-format"></a><a href="#saving-singa-model-into-onnx-format" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Saving SINGA model into ONNX Format</h3>
 <p>Given the input tensors and the output tensors generated by the operators the
@@ -142,410 +129,93 @@
 sonnx.to_onnx([<span class="hljs-symbol">x</span>], [<span class="hljs-symbol">y</span>])
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="re-training-an-onnx-model"></a><a href="#re-training-an-onnx-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Re-training an ONNX model</h3>
-<p>To train (or refine) an ONNX model using SINGA, you need to set the internal
-tensors to be trainable</p>
-<pre><code class="hljs css language-python3"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
+<p>To train (or refine) an ONNX model using SINGA, implement the
+<code>train_one_batch</code> method of your <code>sonnx.SONNXModel</code> subclass and set <code>is_train=True</code> when
+calling <code>model.compile</code>.</p>
+<pre><code class="hljs css language-python3">from singa import opt
+from singa import autograd
 
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span>
-        self.sg_ir = sg_ir
-        <span class="hljs-comment">## can wrap these codes in sonnx?</span>
-        <span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items():
-            <span class="hljs-comment"># allow the tensors to be updated</span>
-            tens.requires_grad = <span class="hljs-literal">True</span>
-            tens.stores_grad = <span class="hljs-literal">True</span>
+<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyModel</span>(<span class="hljs-title">sonnx</span>.<span class="hljs-title">SONNXModel</span>):</span>
 
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>]
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, onnx_model)</span></span>:
+        <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).__init__(onnx_model)
 
-autograd.training = <span class="hljs-literal">False</span>
-model = Infer(sg_ir)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, *x)</span></span>:
+        y = <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).forward(*x)
+        <span class="hljs-keyword">return</span> y[<span class="hljs-number">0</span>]
 
-autograd.training = <span class="hljs-literal">True</span>
-<span class="hljs-comment"># then you training the model like normal</span>
-<span class="hljs-comment"># give more details??</span>
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x, y, dist_option, spars)</span></span>:
+        out = <span class="hljs-keyword">self</span>.forward(x)
+        loss = autograd.softmax_cross_entropy(out, y)
+        <span class="hljs-keyword">if</span> dist_option == <span class="hljs-string">'fp32'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update(loss)
+        elif dist_option == <span class="hljs-string">'fp16'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update_half(loss)
+        elif dist_option == <span class="hljs-string">'partialUpdate'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_partial_update(loss)
+        elif dist_option == <span class="hljs-string">'sparseTopK'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == <span class="hljs-string">'sparseThreshold'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        <span class="hljs-keyword">return</span> out, loss
+
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">set_optimizer</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, optimizer)</span></span>:
+        <span class="hljs-keyword">self</span>.optimizer = optimizer
+
+sgd = opt.SGD(lr=<span class="hljs-number">0.005</span>, momentum=<span class="hljs-number">0.9</span>, weight_decay=<span class="hljs-number">1e-5</span>)
+model.set_optimizer(sgd)
+model.compile([tx], is_train=True, use_graph=graph, sequential=True)
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="transfer-learning-an-onnx-model"></a><a href="#transfer-learning-an-onnx-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Transfer-learning an ONNX model</h3>
-<p>You also can append some layers to the end of ONNX model to do
-transfer-learning. The <code>last_layers</code> means you cut the ONNX layers from [0,
-last_layers]. Then you can append more layers by the normal SINGA model.</p>
-<pre><code class="hljs css language-python3">class Trans:
+<p>You can also append layers to the end of the ONNX model to do
+transfer learning. The <code>last_layers</code> argument accepts a negative integer indicating the
+layer after which to cut. For example, <code>-1</code> cuts after the final output (no
+layer is removed), and <code>-2</code> cuts after the second-to-last layer.</p>
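+<p>The negative-index convention can be illustrated with a minimal, SINGA-free sketch.
+The helper <code>layers_to_run</code> and the layer names are hypothetical and only mirror the
+semantics described above; they are not part of the <code>sonnx</code> API and may differ from
+how SINGA implements the cut internally.</p>

```python
def layers_to_run(layer_names, last_layers=-1):
    """Return the prefix of layer_names kept for a negative last_layers value.

    Illustrative helper only; it is not part of SINGA.
    """
    if not -len(layer_names) <= last_layers <= -1:
        raise ValueError("last_layers must be in [-len(layer_names), -1]")
    # -1 keeps every layer, -2 stops after the second-to-last layer, and so on.
    stop = len(layer_names) + last_layers + 1
    return layer_names[:stop]


# Hypothetical layer names for a small classifier.
layers = ["conv", "relu", "flatten", "fc", "softmax"]
print(layers_to_run(layers, -1))  # all five layers are kept
print(layers_to_run(layers, -3))  # stops after "flatten"
```

+<p>Under this reading, <code>last_layers=-3</code> leaves the output of the third-to-last layer as
+the input to whatever layers you append.</p>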
+<pre><code class="hljs css language-python3">from singa import opt
+from singa import autograd, layer
 
-    def __init__(<span class="hljs-literal">self</span>, sg_ir, last_layers):
-        <span class="hljs-literal">self</span>.sg_ir = sg_ir
-        <span class="hljs-literal">self</span>.last_layers = last_layers
-        <span class="hljs-literal">self</span>.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=False)
-        <span class="hljs-literal">self</span>.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=False)
-        <span class="hljs-literal">self</span>.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=False)
+<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyModel</span>(<span class="hljs-title">sonnx</span>.<span class="hljs-title">SONNXModel</span>):</span>
 
-    def forward(<span class="hljs-literal">self</span>, <span class="hljs-symbol">x</span>):
-        <span class="hljs-symbol">y</span> = sg_ir.run([<span class="hljs-symbol">x</span>], last_layers=<span class="hljs-literal">self</span>.last_layers)[<span class="hljs-number">0</span>]
-        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear1(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear2(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear3(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
-        <span class="hljs-keyword">return</span> <span class="hljs-symbol">y</span>
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, onnx_model)</span></span>:
+        <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).__init__(onnx_model)
+        <span class="hljs-keyword">self</span>.linear = layer.Linear(<span class="hljs-number">1000</span>, <span class="hljs-number">3</span>)
 
-autograd.training = False
-model = Trans(sg_ir, <span class="hljs-number">-1</span>)
-
-# <span class="hljs-keyword">then</span> you training the model like normal
-</code></pre>
-<h2><a class="anchor" aria-hidden="true" id="a-full-example"></a><a href="#a-full-example" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>A Full Example</h2>
-<p>This part introduces the usage of SINGA ONNX by using the mnist example. In this
-section, the examples of how to export, load, inference, re-training, and
-transfer-learning the minist model are displayed. You can try this part
-<a href="https://colab.research.google.com/drive/1-YOfQqqw3HNhS8WpB8xjDQYutRdUdmCq">here</a>.</p>
-<h3><a class="anchor" aria-hidden="true" id="load-dataset"></a><a href="#load-dataset" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Load dataset</h3>
-<p>Firstly, you need to import some necessary libraries and define some auxiliary
-functions for downloading and preprocessing the dataset:</p>
-<pre><code class="hljs css language-python"><span class="hljs-keyword">import</span> os
-<span class="hljs-keyword">import</span> urllib.request
-<span class="hljs-keyword">import</span> gzip
-<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
-<span class="hljs-keyword">import</span> codecs
-
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> device
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> tensor
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> autograd
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> sonnx
-<span class="hljs-keyword">import</span> onnx
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">load_dataset</span><span class="hljs-params">()</span>:</span>
-    train_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'</span>
-    train_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'</span>
-    valid_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'</span>
-    valid_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'</span>
-    train_x = read_image_file(check_exist_or_download(train_x_url)).astype(
-        np.float32)
-    train_y = read_label_file(check_exist_or_download(train_y_url)).astype(
-        np.float32)
-    valid_x = read_image_file(check_exist_or_download(valid_x_url)).astype(
-        np.float32)
-    valid_y = read_label_file(check_exist_or_download(valid_y_url)).astype(
-        np.float32)
-    <span class="hljs-keyword">return</span> train_x, train_y, valid_x, valid_y
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">check_exist_or_download</span><span class="hljs-params">(url)</span>:</span>
-
-    download_dir = <span class="hljs-string">'/tmp/'</span>
-
-    name = url.rsplit(<span class="hljs-string">'/'</span>, <span class="hljs-number">1</span>)[<span class="hljs-number">-1</span>]
-    filename = os.path.join(download_dir, name)
-    <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> os.path.isfile(filename):
-        print(<span class="hljs-string">"Downloading %s"</span> % url)
-        urllib.request.urlretrieve(url, filename)
-    <span class="hljs-keyword">return</span> filename
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_label_file</span><span class="hljs-params">(path)</span>:</span>
-    <span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
-        data = f.read()
-        <span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2049</span>
-        length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">8</span>).reshape(
-            (length))
-        <span class="hljs-keyword">return</span> parsed
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">get_int</span><span class="hljs-params">(b)</span>:</span>
-    <span class="hljs-keyword">return</span> int(codecs.encode(b, <span class="hljs-string">'hex'</span>), <span class="hljs-number">16</span>)
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_image_file</span><span class="hljs-params">(path)</span>:</span>
-    <span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
-        data = f.read()
-        <span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2051</span>
-        length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
-        num_rows = get_int(data[<span class="hljs-number">8</span>:<span class="hljs-number">12</span>])
-        num_cols = get_int(data[<span class="hljs-number">12</span>:<span class="hljs-number">16</span>])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">16</span>).reshape(
-            (length, <span class="hljs-number">1</span>, num_rows, num_cols))
-        <span class="hljs-keyword">return</span> parsed
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">to_categorical</span><span class="hljs-params">(y, num_classes)</span>:</span>
-    y = np.array(y, dtype=<span class="hljs-string">"int"</span>)
-    n = y.shape[<span class="hljs-number">0</span>]
-    categorical = np.zeros((n, num_classes))
-    categorical[np.arange(n), y] = <span class="hljs-number">1</span>
-    categorical = categorical.astype(np.float32)
-    <span class="hljs-keyword">return</span> categorical
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="mnist-model"></a><a href="#mnist-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>MNIST model</h3>
-<p>Then you can define a class called <strong>CNN</strong> to construct the mnist model which
-consists of several convolution, pooling, fully connection and relu layers. You
-can also define a function to calculate the <strong>accuracy</strong> of our result. Finally,
-you can define a <strong>train</strong> and a <strong>test</strong> function to handle the training and
-prediction process.</p>
-<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">CNN</span>:</span>
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self)</span>:</span>
-        self.conv1 = autograd.Conv2d(<span class="hljs-number">1</span>, <span class="hljs-number">20</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
-        self.conv2 = autograd.Conv2d(<span class="hljs-number">20</span>, <span class="hljs-number">50</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
-        self.linear1 = autograd.Linear(<span class="hljs-number">4</span> * <span class="hljs-number">4</span> * <span class="hljs-number">50</span>, <span class="hljs-number">500</span>, bias=<span class="hljs-literal">False</span>)
-        self.linear2 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)
-        self.pooling1 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
-        self.pooling2 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        y = self.conv1(x)
-        y = autograd.relu(y)
-        y = self.pooling1(y)
-        y = self.conv2(y)
-        y = autograd.relu(y)
-        y = self.pooling2(y)
-        y = autograd.flatten(y)
-        y = self.linear1(y)
-        y = autograd.relu(y)
-        y = self.linear2(y)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, *x)</span></span>:
+        <span class="hljs-comment"># cut off after the third-to-last layer</span>
+        <span class="hljs-comment"># and append a linear layer</span>
+        y = <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).forward(*x, last_layers=-<span class="hljs-number">3</span>)[<span class="hljs-number">0</span>]
+        y = <span class="hljs-keyword">self</span>.linear(y)
         <span class="hljs-keyword">return</span> y
 
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x, y, dist_option, spars)</span></span>:
+        out = <span class="hljs-keyword">self</span>.forward(x)
+        loss = autograd.softmax_cross_entropy(out, y)
+        <span class="hljs-keyword">if</span> dist_option == <span class="hljs-string">'fp32'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update(loss)
+        elif dist_option == <span class="hljs-string">'fp16'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update_half(loss)
+        elif dist_option == <span class="hljs-string">'partialUpdate'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_partial_update(loss)
+        elif dist_option == <span class="hljs-string">'sparseTopK'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == <span class="hljs-string">'sparseThreshold'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        <span class="hljs-keyword">return</span> out, loss
 
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">accuracy</span><span class="hljs-params">(pred, target)</span>:</span>
-    y = np.argmax(pred, axis=<span class="hljs-number">1</span>)
-    t = np.argmax(target, axis=<span class="hljs-number">1</span>)
-    a = y == t
-    <span class="hljs-keyword">return</span> np.array(a, <span class="hljs-string">"int"</span>).sum() / float(len(t))
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">set_optimizer</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, optimizer)</span></span>:
+        <span class="hljs-keyword">self</span>.optimizer = optimizer
 
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train</span><span class="hljs-params">(model,
-          x,
-          y,
-          epochs=<span class="hljs-number">1</span>,
-          batch_size=<span class="hljs-number">64</span>,
-          dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
-        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-            output_batch = model.forward(x_batch)
-            <span class="hljs-comment"># onnx_model = sonnx.to_onnx([x_batch], [y])</span>
-            <span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=<span class="hljs-number">0.001</span>)
-            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
-                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
-                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
-    print(<span class="hljs-string">"training completed"</span>)
-    <span class="hljs-keyword">return</span> x_batch, output_batch
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">test</span><span class="hljs-params">(model, x, y, batch_size=<span class="hljs-number">64</span>, dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    result = <span class="hljs-number">0</span>
-    <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-        l_idx = b * batch_size
-        r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-        x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-        target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-        output_batch = model.forward(x_batch)
-        result += accuracy(tensor.to_numpy(output_batch),
-                           tensor.to_numpy(target_batch))
-
-    print(<span class="hljs-string">"testing acc %6.2f"</span> % (result / batch_number))
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="train-mnist-model-and-export-it-to-onnx"></a><a href="#train-mnist-model-and-export-it-to-onnx" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Train mnist model and export it to onnx</h3>
-<p>Now, you can train the mnist model and export its onnx model by calling the
-<strong>soonx.to_onnx</strong> function.</p>
-<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">make_onnx</span><span class="hljs-params">(x, y)</span>:</span>
-    <span class="hljs-keyword">return</span> sonnx.to_onnx([x], [y])
-
-<span class="hljs-comment"># create device</span>
-dev = device.create_cuda_gpu()
-<span class="hljs-comment">#dev = device.get_default_device()</span>
-<span class="hljs-comment"># create model</span>
-model = CNN()
-<span class="hljs-comment"># load data</span>
-train_x, train_y, valid_x, valid_y = load_dataset()
-<span class="hljs-comment"># normalization</span>
-train_x = train_x / <span class="hljs-number">255</span>
-valid_x = valid_x / <span class="hljs-number">255</span>
-train_y = to_categorical(train_y, <span class="hljs-number">10</span>)
-valid_y = to_categorical(valid_y, <span class="hljs-number">10</span>)
-<span class="hljs-comment"># do training</span>
-autograd.training = <span class="hljs-literal">True</span>
-x, y = train(model, train_x, train_y, dev=dev)
-onnx_model = make_onnx(x, y)
-<span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>
-
-<span class="hljs-comment"># Save the ONNX model</span>
-model_path = os.path.join(<span class="hljs-string">'/'</span>, <span class="hljs-string">'tmp'</span>, <span class="hljs-string">'mnist.onnx'</span>)
-onnx.save(onnx_model, model_path)
-print(<span class="hljs-string">'The model is saved.'</span>)
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="inference"></a><a href="#inference" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference</h3>
-<p>After you export the onnx model, you can find a file called <strong>mnist.onnx</strong> in
-the '/tmp' directory, this model, therefore, can be imported by other libraries.
-Now, if you want to import this onnx model into singa again and do the inference
-using the validation dataset, you can define a class called <strong>Infer</strong>, the
-forward function of Infer will be called by the test function to do inference
-for validation dataset. By the way, you should set the label of training to
-<strong>False</strong> to fix the gradient of autograd operators.</p>
-<p>When import the onnx model, you need to call <strong>onnx.load</strong> to load the onnx
-model firstly. Then the onnx model will be fed into the <strong>soonx.prepare</strong> to
-parse and initiate to a singa model(<strong>sg_ir</strong> in the code). The sg_ir contains a
-singa graph within it, and then you can run an step of inference by feeding
-input to its run function.</p>
-<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span>
-        self.sg_ir = sg_ir
-        <span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items():
-            <span class="hljs-comment"># allow the tensors to be updated</span>
-            tens.requires_grad = <span class="hljs-literal">True</span>
-            tens.stores_grad= <span class="hljs-literal">True</span>
-            sg_ir.tensor_map[idx] = tens
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>] <span class="hljs-comment"># we can run one step of inference by feeding input</span>
-
-<span class="hljs-comment"># load the ONNX model</span>
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev) <span class="hljs-comment"># parse and initiate to a singa model</span>
-
-<span class="hljs-comment"># inference</span>
-autograd.training = <span class="hljs-literal">False</span>
-print(<span class="hljs-string">'The inference result is:'</span>)
-test(Infer(sg_ir), valid_x, valid_y, dev=dev)
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="re-training"></a><a href="#re-training" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Re-training</h3>
-<p>Assume after import the model, you want to re-train the model again, we can
-define a function called <strong>re_train</strong>. Before we call this re_train function, we
-should set the label of training to <strong>True</strong> to make the autograde operators
-update their gradient. And after we finish the training, we set it as <strong>False</strong>
-again to call the test function doing inference.</p>
-<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">re_train</span><span class="hljs-params">(sg_ir,
-             x,
-             y,
-             epochs=<span class="hljs-number">1</span>,
-             batch_size=<span class="hljs-number">64</span>,
-             dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    new_model = Infer(sg_ir)
-
-    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
-        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-            output_batch = new_model.forward(x_batch)
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=<span class="hljs-number">0.01</span>)
-            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
-                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
-                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
-    print(<span class="hljs-string">"re-training completed"</span>)
-    <span class="hljs-keyword">return</span> new_model
-
-<span class="hljs-comment"># load the ONNX model</span>
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev)
-
-<span class="hljs-comment"># re-training</span>
-autograd.training = <span class="hljs-literal">True</span>
-new_model = re_train(sg_ir, train_x, train_y, dev=dev)
-autograd.training = <span class="hljs-literal">False</span>
-test(new_model, valid_x, valid_y, dev=dev)
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="transfer-learning"></a><a href="#transfer-learning" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Transfer learning</h3>
-<p>Finally, if we want to do transfer-learning, we can define a function called
-<strong>Trans</strong> to append some layers after the onnx model. For demonstration, the
-code only appends several linear(fully connection) and relu after the onnx
-model. You can define a transfer_learning function to handle the training
-process of the transfer-learning model. And the label of training is the same as
-the previous one.</p>
-<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Trans</span>:</span>
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir, last_layers)</span>:</span>
-        self.sg_ir = sg_ir
-        self.last_layers = last_layers
-        self.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=<span class="hljs-literal">False</span>)
-        self.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=<span class="hljs-literal">False</span>)
-        self.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        y = sg_ir.run([x], last_layers=self.last_layers)[<span class="hljs-number">0</span>]
-        y = self.append_linear1(y)
-        y = autograd.relu(y)
-        y = self.append_linear2(y)
-        y = autograd.relu(y)
-        y = self.append_linear3(y)
-        y = autograd.relu(y)
-        <span class="hljs-keyword">return</span> y
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">transfer_learning</span><span class="hljs-params">(sg_ir,
-             x,
-             y,
-             epochs=<span class="hljs-number">1</span>,
-             batch_size=<span class="hljs-number">64</span>,
-             dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    trans_model = Trans(sg_ir, <span class="hljs-number">-1</span>)
-
-    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
-        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-            output_batch = trans_model.forward(x_batch)
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=<span class="hljs-number">0.07</span>)
-            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
-                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
-                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
-    print(<span class="hljs-string">"transfer-learning completed"</span>)
-    <span class="hljs-keyword">return</span> trans_mode
-
-<span class="hljs-comment"># load the ONNX model</span>
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev)
-
-<span class="hljs-comment"># transfer-learning</span>
-autograd.training = <span class="hljs-literal">True</span>
-new_model = transfer_learning(sg_ir, train_x, train_y, dev=dev)
-autograd.training = <span class="hljs-literal">False</span>
-test(new_model, valid_x, valid_y, dev=dev)
+sgd = opt.SGD(lr=<span class="hljs-number">0</span>.<span class="hljs-number">005</span>, momentum=<span class="hljs-number">0</span>.<span class="hljs-number">9</span>, weight_decay=<span class="hljs-number">1</span>e-<span class="hljs-number">5</span>)
+model.set_optimizer(sgd)
+model.compile([tx], is_train=True, use_graph=graph, sequential=True)
 </code></pre>
 <h2><a class="anchor" aria-hidden="true" id="onnx-model-zoo"></a><a href="#onnx-model-zoo" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>ONNX model zoo</h2>
 <p>The <a href="https://github.com/onnx/models">ONNX Model Zoo</a> is a collection of
@@ -567,6 +237,8 @@
 <tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/shufflenet">ShuffleNet_V2</a></b></td><td><a href="https://arxiv.org/pdf/1707.01083.pdf">Simonyan et al.</a></td><td>Extremely computation efficient CNN model that is designed specifically for mobile devices. This network architecture design considers direct metric such as speed, instead of indirect metric like FLOP. Top-1 error from paper - ~30.6%</td><td>[<img src="https://colab.research.google.com/drive/19HfRu3YHP_H2z3BcZujVFRp23_J5XsuA?usp=sharing" alt="Open In Colab"></td></tr>
 </tbody>
 </table>
+<p>We also provide some re-training examples using VGG and ResNet; please check
+<code>examples/onnx/training</code>.</p>
 <h3><a class="anchor" aria-hidden="true" id="object-detection"></a><a href="#object-detection" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Object Detection</h3>
 <p>Object detection models detect the presence of multiple objects in an image and
 segment out areas of the image where the objects are detected.</p>
@@ -698,11 +370,13 @@
 <li><p>Empty tensor Empty tensor is illegal in SINGA.</p></li>
 </ul>
 <h2><a class="anchor" aria-hidden="true" id="implementation"></a><a href="#implementation" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Implementation</h2>
-<p>The code of SINGA ONNX locates at <code>python/singa/soonx.py</code>. There are three main
-class, <code>SingaFrontend</code> and <code>SingaBackend</code> and <code>SingaRep</code>. <code>SingaFrontend</code>
-translates a SINGA model to ONNX model; <code>SingaBackend</code> translates a ONNX model
-to <code>SingaRep</code> object which stores all SINGA operators and tensors(the tensor in
-this doc means SINGA <code>Tensor</code>); <code>SingaRep</code> can be run like a SINGA model.</p>
+<p>The SINGA ONNX code is located at <code>python/singa/sonnx.py</code>. There are four main
+classes: <code>SingaFrontend</code>, <code>SingaBackend</code>, <code>SingaRep</code> and <code>SONNXModel</code>.
+<code>SingaFrontend</code> translates a SINGA model to an ONNX model; <code>SingaBackend</code>
+translates an ONNX model to a <code>SingaRep</code> object, which stores all SINGA operators
+and tensors (a tensor in this doc means a SINGA <code>Tensor</code>); <code>SingaRep</code> can be run
+like a SINGA model. <code>SONNXModel</code> inherits from <code>model.Model</code>, which defines a
+unified API for SINGA.</p>
 <h3><a class="anchor" aria-hidden="true" id="singafrontend"></a><a href="#singafrontend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaFrontend</h3>
 <p>The entry function of <code>SingaFrontend</code> is <code>singa_to_onnx_model</code> which also is
 called <code>to_onnx</code>. <code>singa_to_onnx_model</code> creates the ONNX model, and it also
@@ -724,40 +398,32 @@
 operators to be changed as bool type.</p>
 <h3><a class="anchor" aria-hidden="true" id="singabackend"></a><a href="#singabackend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaBackend</h3>
 <p>The entry function of <code>SingaBackend</code> is <code>prepare</code> which checks the version of
-ONNX model and call <code>_onnx_model_to_singa_net</code> then.</p>
-<p>The purpose of <code>_onnx_model_to_singa_net</code> is to get SINGA tensors and operators.
+the ONNX model and then calls <code>_onnx_model_to_singa_ops</code>.</p>
+<p>The purpose of <code>_onnx_model_to_singa_ops</code> is to get SINGA tensors and operators.
 The tensors are stored in a dictionary by their name in ONNX, and operators are
-stored in queue by the form of
-<code>namedtuple('SingaOps', ['name', 'op', 'handle', 'forward'])</code>. For each
-operator, <code>name</code> is its ONNX node name; <code>op</code> is the ONNX node; <code>forward</code> is the
-SINGA operator's forward function; <code>handle</code> is prepared for some special
-operators such as Conv and Pooling which has <code>handle</code> object.</p>
-<p>The first step of <code>_onnx_model_to_singa_net</code> is to call <code>_init_graph_parameter</code>
-to get all tensors within the model. For trainable weights, it can init SINGA
-<code>Tensor</code> from <code>onnx_model.graph.initializer</code>. Please note, the weights may also
-be stored within graph's input or a ONNX node called <code>Constant</code>, SINGA can also
-handle these.</p>
-<p>Though all weights are stored within ONNX model, the input of the model is
-unknown but its shape and type. So SINGA support two ways to init input, 1,
-generate random tensor by its shape and type, 2, allow the user to assign the
-input. The first way works fine for most models, however, for some model such as
-bert, the indices of matrix cannot be random generated otherwise it will incurs
-errors.</p>
-<p>Then, <code>_onnx_model_to_singa_net</code> iterators all nodes within ONNX graph to
-translate it to SIGNA operators. Also, <code>_rename_operators</code> defines the operators
-name mapping between SINGA and ONNX. <code>_special_operators</code> defines which function
-to be used to translate the operator. <code>_run_node</code> runs the generated SINGA model
-by its input tensors and store its output tensors for being used by later
-operators.</p>
-<p>This class finally return a <code>SingaRep</code> object and stores all SINGA tensors and
-operators within it.</p>
+stored in a queue in the form of <code>namedtuple('SingaOps', ['node', 'operator'])</code>.
+For each operator, <code>node</code> is an instance of <code>OnnxNode</code>, which is defined to store
+some basic information about an ONNX node; <code>operator</code> is the SINGA operator's
+forward function.</p>
+<p><code>_onnx_model_to_singa_ops</code> works in four steps. First, it calls
+<code>_parse_graph_params</code> to get all tensors, which are stored as <code>params</code>. Second, it calls
+<code>_parse_graph_inputs_outputs</code> to get all input and output information, stored as
+<code>inputs</code> and <code>outputs</code>. Third, it iterates over all nodes within the ONNX graph
+and parses each one with <code>_onnx_node_to_singa_op</code> into SINGA operators or layers, which
+are stored as <code>layers</code>. Finally, because some weights are stored within an ONNX node called
+<code>Constant</code>, SONNX converts them with <code>_onnx_constant_to_np</code> and stores them in
+<code>params</code>.</p>
+<p>This class ultimately returns a <code>SingaRep</code> object that stores the above <code>params</code>,
+<code>inputs</code>, <code>outputs</code> and <code>layers</code>.</p>
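+<p>The parsing flow described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the SONNX implementation: the graph is a hypothetical dictionary rather than an ONNX protobuf, <code>parse_graph</code> and <code>toy_graph</code> are made-up names, and treating initializers that also appear in the input list as weights is an assumption of this sketch.</p>

```python
from collections import namedtuple

# Same record shape as named in the text; everything else is a stand-in.
SingaOps = namedtuple('SingaOps', ['node', 'operator'])


def parse_graph(graph):
    """Toy version of the parsing steps, using plain dicts instead of ONNX protos."""
    # Collect the stored weights, keyed by their ONNX name.
    params = dict(graph['initializer'])
    # Record input/output names (assumption: weights listed as inputs are skipped).
    inputs = [name for name in graph['input'] if name not in params]
    outputs = list(graph['output'])
    # Walk the nodes: Constant nodes become params, all others become queued ops.
    layers = []
    for node in graph['node']:
        if node['op_type'] == 'Constant':
            params[node['name']] = node['value']
        else:
            # A real backend would store a forward function; a label stands in here.
            layers.append(SingaOps(node=node, operator=node['op_type'].lower()))
    return params, inputs, outputs, layers


# Hypothetical graph with one weight, one constant and two operators.
toy_graph = {
    'initializer': {'W': [[1.0, 2.0]]},
    'input': ['x', 'W'],
    'output': ['y'],
    'node': [
        {'name': 'c0', 'op_type': 'Constant', 'value': [0.5]},
        {'name': 'mm', 'op_type': 'MatMul'},
        {'name': 'add', 'op_type': 'Add'},
    ],
}
params, inputs, outputs, layers = parse_graph(toy_graph)
print(sorted(params), inputs, outputs, [op.operator for op in layers])
```

+<p>Keeping the operators in an ordered list mirrors the queue mentioned above, so a later run can replay them in graph order.</p>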
 <h3><a class="anchor" aria-hidden="true" id="singarep"></a><a href="#singarep" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaRep</h3>
 <p><code>SingaBackend</code> stores all SINGA tensors and operators. <code>run</code> accepts the input
-of the model and run the SINGA operators one by one following the operators
-queue. The user can use <code>last_layers</code> to decide to run the model till the last
-few layers. Set <code>all_outputs</code> as <code>False</code> to get only the final output, <code>True</code> to
-also get all the intermediate output.</p>
-</span></div></article></div><div class="docLastUpdate"><em>Last updated on 20/09/2020</em></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/next/graph"><span class="arrow-prev">← </span><span>Model</span></a><a class="docs-next button" href="/docs/next/dist-train"><span>Distributed Training</span><span class="arrow-next"> →</span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#general-usage">General usage</a><ul class="toc-headings"><li><a href="#loading-an-onnx-model-into-singa">Loading an ONNX Model into SINGA</a></li><li><a href="#inference-singa-model">Inference SINGA model</a></li><li><a href="#saving-singa-model-into-onnx-format">Saving SINGA model into ONNX Format</a></li><li><a href="#re-training-an-onnx-model">Re-training an ONNX model</a></li><li><a href="#transfer-learning-an-onnx-model">Transfer-learning an ONNX model</a></li></ul></li><li><a href="#a-full-example">A Full Example</a><ul class="toc-headings"><li><a href="#load-dataset">Load dataset</a></li><li><a href="#mnist-model">MNIST model</a></li><li><a href="#train-mnist-model-and-export-it-to-onnx">Train mnist model and export it to onnx</a></li><li><a href="#inference">Inference</a></li><li><a href="#re-training">Re-training</a></li><li><a href="#transfer-learning">Transfer learning</a></li></ul></li><li><a href="#onnx-model-zoo">ONNX model zoo</a><ul class="toc-headings"><li><a href="#image-classification">Image Classification</a></li><li><a href="#object-detection">Object Detection</a></li><li><a href="#face-analysis">Face Analysis</a></li><li><a href="#machine-comprehension">Machine Comprehension</a></li></ul></li><li><a href="#supported-operators">Supported operators</a><ul class="toc-headings"><li><a href="#special-comments-for-onnx-backend">Special comments for ONNX backend</a></li></ul></li><li><a href="#implementation">Implementation</a><ul class="toc-headings"><li><a href="#singafrontend">SingaFrontend</a></li><li><a 
href="#singabackend">SingaBackend</a></li><li><a href="#singarep">SingaRep</a></li></ul></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a href="/en/https://apache-singa.readthedocs.io/en/latest/">API Reference</a><a href="/docs/examples">Examples</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/blog">SINGA News</a><a href="https://github.com/apache/singa">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img 
src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2020
+of the model and runs the SINGA operators one by one following the operators'
+queue. The user can use <code>last_layers</code> to cut off the model after the last few
+layers.</p>
+<h3><a class="anchor" aria-hidden="true" id="sonnxmodel"></a><a href="#sonnxmodel" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SONNXModel</h3>
+<p><code>SONNXModel</code> inherits from <code>model.Model</code> and implements the method
+<code>forward</code> to provide a unified API with other SINGA models.</p>
+</span></div></article></div><div class="docLastUpdate"><em>Last updated on 25/11/2020</em></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/next/graph"><span class="arrow-prev">← </span><span>Model</span></a><a class="docs-next button" href="/docs/next/dist-train"><span>Distributed Training</span><span class="arrow-next"> →</span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#general-usage">General usage</a><ul class="toc-headings"><li><a href="#loading-an-onnx-model-into-singa">Loading an ONNX Model into SINGA</a></li><li><a href="#inference-singa-model">Inference SINGA model</a></li><li><a href="#saving-singa-model-into-onnx-format">Saving SINGA model into ONNX Format</a></li><li><a href="#re-training-an-onnx-model">Re-training an ONNX model</a></li><li><a href="#transfer-learning-an-onnx-model">Transfer-learning an ONNX model</a></li></ul></li><li><a href="#onnx-model-zoo">ONNX model zoo</a><ul class="toc-headings"><li><a href="#image-classification">Image Classification</a></li><li><a href="#object-detection">Object Detection</a></li><li><a href="#face-analysis">Face Analysis</a></li><li><a href="#machine-comprehension">Machine Comprehension</a></li></ul></li><li><a href="#supported-operators">Supported operators</a><ul class="toc-headings"><li><a href="#special-comments-for-onnx-backend">Special comments for ONNX backend</a></li></ul></li><li><a href="#implementation">Implementation</a><ul class="toc-headings"><li><a href="#singafrontend">SingaFrontend</a></li><li><a href="#singabackend">SingaBackend</a></li><li><a href="#singarep">SingaRep</a></li><li><a href="#sonnxmodel">SONNXModel</a></li></ul></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a 
href="/en/https://apache-singa.readthedocs.io/en/latest/">API Reference</a><a href="/docs/examples">Examples</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/blog">SINGA News</a><a href="https://github.com/apache/singa">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2020
    The Apache Software Foundation. All rights reserved.
    Apache SINGA, Apache, the Apache feather logo, and
    the Apache SINGA project logos are trademarks of The
diff --git a/content/docs/next/onnx/index.html b/content/docs/next/onnx/index.html
index 62b63bb..425ce07 100644
--- a/content/docs/next/onnx/index.html
+++ b/content/docs/next/onnx/index.html
@@ -78,60 +78,47 @@
 </table>
 <h2><a class="anchor" aria-hidden="true" id="general-usage"></a><a href="#general-usage" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>General usage</h2>
 <h3><a class="anchor" aria-hidden="true" id="loading-an-onnx-model-into-singa"></a><a href="#loading-an-onnx-model-into-singa" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Loading an ONNX Model into SINGA</h3>
-<p>After loading an ONNX model from disk by <code>onnx.load</code>, you need to update the
-model's batchsize, since for most models, they use a placeholder to represent
-its batchsize. We give an example here, as <code>update_batch_size</code>. You only need to
-update the batchsize of input and output, the shape of internal tensors will be
-inferred automatically.</p>
-<p>Then, you can prepare the SINGA model by using <code>sonnx.prepare</code>. This function
-iterates and translates all the nodes within the ONNX model's graph into SINGA
-operators, loads all stored weights and infers each intermediate tensor's shape.</p>
-<pre><code class="hljs css language-python3"><span class="hljs-built_in">import</span> onnx
-from singa <span class="hljs-built_in">import</span> device
-from singa <span class="hljs-built_in">import</span> sonnx
+<p>After loading an ONNX model from disk with <code>onnx.load</code>, you only need to set
+the batch size of the input using <code>tensor.PlaceHolder</code> (available after SINGA v3.0); the shapes
+of internal tensors are inferred automatically.</p>
+<p>Then, define a class that inherits from <code>sonnx.SONNXModel</code> and implement
+two methods: <code>forward</code> for the forward pass and <code>train_one_batch</code> for training.
+When you call <code>model.compile</code>, SONNX iterates over and translates all the nodes
+within the ONNX model's graph into SINGA operators, loads all stored weights, and
+infers each intermediate tensor's shape.</p>
+<pre><code class="hljs css language-python3"><span class="hljs-keyword">import</span> onnx
+<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> device, tensor
+<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> sonnx
 
-<span class="hljs-comment"># if the input has multiple tensors? can put this function inside prepare()?</span>
-def update_batch_size(onnx_model, batch_size):
-    <span class="hljs-attr">model_input</span> = onnx_model.graph.input[<span class="hljs-number">0</span>]
-    model_input.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size
-    <span class="hljs-attr">model_output</span> = onnx_model.graph.output[<span class="hljs-number">0</span>]
-    model_output.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size
-    return onnx_model
+<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyModel</span><span class="hljs-params">(sonnx.SONNXModel)</span>:</span>
 
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, onnx_model)</span>:</span>
+        super(MyModel, self).__init__(onnx_model)
 
-<span class="hljs-attr">model_path</span> = <span class="hljs-string">"PATH/To/ONNX/MODEL"</span>
-<span class="hljs-attr">onnx_model</span> = onnx.load(model_path)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, *x)</span>:</span>
+        y = super(MyModel, self).forward(*x)
+        <span class="hljs-comment"># Since SINGA model returns the output as a list,</span>
+        <span class="hljs-comment"># if there is only one output,</span>
+        <span class="hljs-comment"># you just need to take the first element.</span>
+        <span class="hljs-keyword">return</span> y[<span class="hljs-number">0</span>]
 
-<span class="hljs-comment"># set batch size</span>
-<span class="hljs-attr">onnx_model</span> = update_batch_size(onnx_model, <span class="hljs-number">1</span>)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(self, x, y)</span>:</span>
+        <span class="hljs-keyword">pass</span>
 
-<span class="hljs-comment"># convert onnx graph nodes into SINGA operators</span>
-<span class="hljs-attr">dev</span> = device.create_cuda_gpu()
-<span class="hljs-attr">sg_ir</span> = sonnx.prepare(onnx_model, <span class="hljs-attr">device=dev)</span>
+model_path = <span class="hljs-string">"PATH/To/ONNX/MODEL"</span>
+onnx_model = onnx.load(model_path)
+
+<span class="hljs-comment"># convert onnx model into SINGA model</span>
+dev = device.create_cuda_gpu()
+x = tensor.PlaceHolder(INPUT.shape, device=dev)
+model = MyModel(onnx_model)
+model.compile([x], is_train=<span class="hljs-literal">False</span>, use_graph=<span class="hljs-literal">True</span>, sequential=<span class="hljs-literal">True</span>)
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="inference-singa-model"></a><a href="#inference-singa-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference SINGA model</h3>
-<p>Once the model is created, you can do inference by calling <code>sg_ir.run</code>. The
-input and output must be SINGA <code>Tensor</code> instances. Since SINGA model returns the
-output as a list, if there is only one output, you just need to take the first
-element from the output.</p>
-<pre><code class="hljs css language-python3"><span class="hljs-comment"># can warp the following code in prepare()</span>
-<span class="hljs-comment"># and provide a flag training=True/False?</span>
-
-<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
-
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, sg_ir)</span></span>:
-        <span class="hljs-keyword">self</span>.sg_ir = sg_ir
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x)</span></span>:
-        <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>]
-
-
-data = get_dataset()
-x = tensor.Tensor(device=dev, data=data)
-
-model = Infer(sg_ir)
-y = model.forward(x)
+<p>Once the model is created, you can do inference by calling <code>model.forward</code>. The
+input and output must be SINGA <code>Tensor</code> instances.</p>
+<pre><code class="hljs css language-python3"><span class="hljs-attr">x</span> = tensor.Tensor(device=dev, data=INPUT)
+<span class="hljs-attr">y</span> = model.forward(x)
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="saving-singa-model-into-onnx-format"></a><a href="#saving-singa-model-into-onnx-format" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Saving SINGA model into ONNX Format</h3>
 <p>Given the input tensors and the output tensors generated by the operators the
@@ -142,410 +129,93 @@
 sonnx.to_onnx([<span class="hljs-symbol">x</span>], [<span class="hljs-symbol">y</span>])
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="re-training-an-onnx-model"></a><a href="#re-training-an-onnx-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Re-training an ONNX model</h3>
-<p>To train (or refine) an ONNX model using SINGA, you need to set the internal
-tensors to be trainable</p>
-<pre><code class="hljs css language-python3"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
+<p>To train (or refine) an ONNX model using SINGA, implement the
+<code>train_one_batch</code> method of your <code>sonnx.SONNXModel</code> subclass and pass <code>is_train=True</code> when
+calling <code>model.compile</code>.</p>
+<pre><code class="hljs css language-python3">from singa import opt
+from singa import autograd
 
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span>
-        self.sg_ir = sg_ir
-        <span class="hljs-comment">## can wrap these codes in sonnx?</span>
-        <span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items():
-            <span class="hljs-comment"># allow the tensors to be updated</span>
-            tens.requires_grad = <span class="hljs-literal">True</span>
-            tens.stores_grad = <span class="hljs-literal">True</span>
+<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyModel</span>(<span class="hljs-title">sonnx</span>.<span class="hljs-title">SONNXModel</span>):</span>
 
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>]
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, onnx_model)</span></span>:
+        <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).__init_<span class="hljs-number">_</span>(onnx_model)
 
-autograd.training = <span class="hljs-literal">False</span>
-model = Infer(sg_ir)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, *x)</span></span>:
+        y = <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).forward(*x)
+        <span class="hljs-keyword">return</span> y[<span class="hljs-number">0</span>]
 
-autograd.training = <span class="hljs-literal">True</span>
-<span class="hljs-comment"># then you training the model like normal</span>
-<span class="hljs-comment"># give more details??</span>
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x, y, dist_option, spars)</span></span>:
+        out = <span class="hljs-keyword">self</span>.forward(x)
+        loss = autograd.softmax_cross_entropy(out, y)
+        <span class="hljs-keyword">if</span> dist_option == <span class="hljs-string">'fp32'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update(loss)
+        elif dist_option == <span class="hljs-string">'fp16'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update_half(loss)
+        elif dist_option == <span class="hljs-string">'partialUpdate'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_partial_update(loss)
+        elif dist_option == <span class="hljs-string">'sparseTopK'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == <span class="hljs-string">'sparseThreshold'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        <span class="hljs-keyword">return</span> out, loss
+
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">set_optimizer</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, optimizer)</span></span>:
+        <span class="hljs-keyword">self</span>.optimizer = optimizer
+
+sgd = opt.SGD(lr=<span class="hljs-number">0</span>.<span class="hljs-number">005</span>, momentum=<span class="hljs-number">0</span>.<span class="hljs-number">9</span>, weight_decay=<span class="hljs-number">1</span>e-<span class="hljs-number">5</span>)
+model.set_optimizer(sgd)
+model.compile([tx], is_train=True, use_graph=graph, sequential=True)
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="transfer-learning-an-onnx-model"></a><a href="#transfer-learning-an-onnx-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Transfer-learning an ONNX model</h3>
-<p>You also can append some layers to the end of ONNX model to do
-transfer-learning. The <code>last_layers</code> means you cut the ONNX layers from [0,
-last_layers]. Then you can append more layers by the normal SINGA model.</p>
-<pre><code class="hljs css language-python3">class Trans:
+<p>You can also append layers to the end of the ONNX model to do
+transfer learning. The <code>last_layers</code> argument accepts a negative integer indicating the
+layer to cut off after. For example, <code>-1</code> means cut off after the final output (do
+not cut off any layer), and <code>-2</code> means cut off after the second-to-last layer.</p>
+<pre><code class="hljs css language-python3">from singa import layer, opt
+from singa import autograd
 
-    def __init__(<span class="hljs-literal">self</span>, sg_ir, last_layers):
-        <span class="hljs-literal">self</span>.sg_ir = sg_ir
-        <span class="hljs-literal">self</span>.last_layers = last_layers
-        <span class="hljs-literal">self</span>.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=False)
-        <span class="hljs-literal">self</span>.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=False)
-        <span class="hljs-literal">self</span>.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=False)
+<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyModel</span>(<span class="hljs-title">sonnx</span>.<span class="hljs-title">SONNXModel</span>):</span>
 
-    def forward(<span class="hljs-literal">self</span>, <span class="hljs-symbol">x</span>):
-        <span class="hljs-symbol">y</span> = sg_ir.run([<span class="hljs-symbol">x</span>], last_layers=<span class="hljs-literal">self</span>.last_layers)[<span class="hljs-number">0</span>]
-        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear1(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear2(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear3(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
-        <span class="hljs-keyword">return</span> <span class="hljs-symbol">y</span>
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, onnx_model)</span></span>:
+        <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).__init_<span class="hljs-number">_</span>(onnx_model)
+        <span class="hljs-keyword">self</span>.linear = layer.Linear(<span class="hljs-number">1000</span>, <span class="hljs-number">3</span>)
 
-autograd.training = False
-model = Trans(sg_ir, <span class="hljs-number">-1</span>)
-
-# <span class="hljs-keyword">then</span> you training the model like normal
-</code></pre>
-<h2><a class="anchor" aria-hidden="true" id="a-full-example"></a><a href="#a-full-example" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>A Full Example</h2>
-<p>This part introduces the usage of SINGA ONNX by using the mnist example. In this
-section, the examples of how to export, load, inference, re-training, and
-transfer-learning the minist model are displayed. You can try this part
-<a href="https://colab.research.google.com/drive/1-YOfQqqw3HNhS8WpB8xjDQYutRdUdmCq">here</a>.</p>
-<h3><a class="anchor" aria-hidden="true" id="load-dataset"></a><a href="#load-dataset" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Load dataset</h3>
-<p>Firstly, you need to import some necessary libraries and define some auxiliary
-functions for downloading and preprocessing the dataset:</p>
-<pre><code class="hljs css language-python"><span class="hljs-keyword">import</span> os
-<span class="hljs-keyword">import</span> urllib.request
-<span class="hljs-keyword">import</span> gzip
-<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
-<span class="hljs-keyword">import</span> codecs
-
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> device
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> tensor
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> autograd
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> sonnx
-<span class="hljs-keyword">import</span> onnx
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">load_dataset</span><span class="hljs-params">()</span>:</span>
-    train_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'</span>
-    train_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'</span>
-    valid_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'</span>
-    valid_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'</span>
-    train_x = read_image_file(check_exist_or_download(train_x_url)).astype(
-        np.float32)
-    train_y = read_label_file(check_exist_or_download(train_y_url)).astype(
-        np.float32)
-    valid_x = read_image_file(check_exist_or_download(valid_x_url)).astype(
-        np.float32)
-    valid_y = read_label_file(check_exist_or_download(valid_y_url)).astype(
-        np.float32)
-    <span class="hljs-keyword">return</span> train_x, train_y, valid_x, valid_y
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">check_exist_or_download</span><span class="hljs-params">(url)</span>:</span>
-
-    download_dir = <span class="hljs-string">'/tmp/'</span>
-
-    name = url.rsplit(<span class="hljs-string">'/'</span>, <span class="hljs-number">1</span>)[<span class="hljs-number">-1</span>]
-    filename = os.path.join(download_dir, name)
-    <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> os.path.isfile(filename):
-        print(<span class="hljs-string">"Downloading %s"</span> % url)
-        urllib.request.urlretrieve(url, filename)
-    <span class="hljs-keyword">return</span> filename
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_label_file</span><span class="hljs-params">(path)</span>:</span>
-    <span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
-        data = f.read()
-        <span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2049</span>
-        length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">8</span>).reshape(
-            (length))
-        <span class="hljs-keyword">return</span> parsed
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">get_int</span><span class="hljs-params">(b)</span>:</span>
-    <span class="hljs-keyword">return</span> int(codecs.encode(b, <span class="hljs-string">'hex'</span>), <span class="hljs-number">16</span>)
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_image_file</span><span class="hljs-params">(path)</span>:</span>
-    <span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
-        data = f.read()
-        <span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2051</span>
-        length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
-        num_rows = get_int(data[<span class="hljs-number">8</span>:<span class="hljs-number">12</span>])
-        num_cols = get_int(data[<span class="hljs-number">12</span>:<span class="hljs-number">16</span>])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">16</span>).reshape(
-            (length, <span class="hljs-number">1</span>, num_rows, num_cols))
-        <span class="hljs-keyword">return</span> parsed
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">to_categorical</span><span class="hljs-params">(y, num_classes)</span>:</span>
-    y = np.array(y, dtype=<span class="hljs-string">"int"</span>)
-    n = y.shape[<span class="hljs-number">0</span>]
-    categorical = np.zeros((n, num_classes))
-    categorical[np.arange(n), y] = <span class="hljs-number">1</span>
-    categorical = categorical.astype(np.float32)
-    <span class="hljs-keyword">return</span> categorical
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="mnist-model"></a><a href="#mnist-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>MNIST model</h3>
-<p>Then you can define a class called <strong>CNN</strong> to construct the mnist model which
-consists of several convolution, pooling, fully connection and relu layers. You
-can also define a function to calculate the <strong>accuracy</strong> of our result. Finally,
-you can define a <strong>train</strong> and a <strong>test</strong> function to handle the training and
-prediction process.</p>
-<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">CNN</span>:</span>
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self)</span>:</span>
-        self.conv1 = autograd.Conv2d(<span class="hljs-number">1</span>, <span class="hljs-number">20</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
-        self.conv2 = autograd.Conv2d(<span class="hljs-number">20</span>, <span class="hljs-number">50</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
-        self.linear1 = autograd.Linear(<span class="hljs-number">4</span> * <span class="hljs-number">4</span> * <span class="hljs-number">50</span>, <span class="hljs-number">500</span>, bias=<span class="hljs-literal">False</span>)
-        self.linear2 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)
-        self.pooling1 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
-        self.pooling2 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        y = self.conv1(x)
-        y = autograd.relu(y)
-        y = self.pooling1(y)
-        y = self.conv2(y)
-        y = autograd.relu(y)
-        y = self.pooling2(y)
-        y = autograd.flatten(y)
-        y = self.linear1(y)
-        y = autograd.relu(y)
-        y = self.linear2(y)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, *x)</span></span>:
+        <span class="hljs-comment"># cut off after the third-to-last layer</span>
+        <span class="hljs-comment"># and append a linear layer</span>
+        y = <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).forward(*x, last_layers=-<span class="hljs-number">3</span>)[<span class="hljs-number">0</span>]
+        y = <span class="hljs-keyword">self</span>.linear(y)
         <span class="hljs-keyword">return</span> y
 
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x, y, dist_option, spars)</span></span>:
+        out = <span class="hljs-keyword">self</span>.forward(x)
+        loss = autograd.softmax_cross_entropy(out, y)
+        <span class="hljs-keyword">if</span> dist_option == <span class="hljs-string">'fp32'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update(loss)
+        elif dist_option == <span class="hljs-string">'fp16'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update_half(loss)
+        elif dist_option == <span class="hljs-string">'partialUpdate'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_partial_update(loss)
+        elif dist_option == <span class="hljs-string">'sparseTopK'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == <span class="hljs-string">'sparseThreshold'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        <span class="hljs-keyword">return</span> out, loss
 
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">accuracy</span><span class="hljs-params">(pred, target)</span>:</span>
-    y = np.argmax(pred, axis=<span class="hljs-number">1</span>)
-    t = np.argmax(target, axis=<span class="hljs-number">1</span>)
-    a = y == t
-    <span class="hljs-keyword">return</span> np.array(a, <span class="hljs-string">"int"</span>).sum() / float(len(t))
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">set_optimizer</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, optimizer)</span></span>:
+        <span class="hljs-keyword">self</span>.optimizer = optimizer
 
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train</span><span class="hljs-params">(model,
-          x,
-          y,
-          epochs=<span class="hljs-number">1</span>,
-          batch_size=<span class="hljs-number">64</span>,
-          dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
-        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-            output_batch = model.forward(x_batch)
-            <span class="hljs-comment"># onnx_model = sonnx.to_onnx([x_batch], [y])</span>
-            <span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=<span class="hljs-number">0.001</span>)
-            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
-                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
-                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
-    print(<span class="hljs-string">"training completed"</span>)
-    <span class="hljs-keyword">return</span> x_batch, output_batch
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">test</span><span class="hljs-params">(model, x, y, batch_size=<span class="hljs-number">64</span>, dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    result = <span class="hljs-number">0</span>
-    <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-        l_idx = b * batch_size
-        r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-        x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-        target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-        output_batch = model.forward(x_batch)
-        result += accuracy(tensor.to_numpy(output_batch),
-                           tensor.to_numpy(target_batch))
-
-    print(<span class="hljs-string">"testing acc %6.2f"</span> % (result / batch_number))
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="train-mnist-model-and-export-it-to-onnx"></a><a href="#train-mnist-model-and-export-it-to-onnx" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Train mnist model and export it to onnx</h3>
-<p>Now, you can train the mnist model and export its onnx model by calling the
-<strong>soonx.to_onnx</strong> function.</p>
-<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">make_onnx</span><span class="hljs-params">(x, y)</span>:</span>
-    <span class="hljs-keyword">return</span> sonnx.to_onnx([x], [y])
-
-<span class="hljs-comment"># create device</span>
-dev = device.create_cuda_gpu()
-<span class="hljs-comment">#dev = device.get_default_device()</span>
-<span class="hljs-comment"># create model</span>
-model = CNN()
-<span class="hljs-comment"># load data</span>
-train_x, train_y, valid_x, valid_y = load_dataset()
-<span class="hljs-comment"># normalization</span>
-train_x = train_x / <span class="hljs-number">255</span>
-valid_x = valid_x / <span class="hljs-number">255</span>
-train_y = to_categorical(train_y, <span class="hljs-number">10</span>)
-valid_y = to_categorical(valid_y, <span class="hljs-number">10</span>)
-<span class="hljs-comment"># do training</span>
-autograd.training = <span class="hljs-literal">True</span>
-x, y = train(model, train_x, train_y, dev=dev)
-onnx_model = make_onnx(x, y)
-<span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>
-
-<span class="hljs-comment"># Save the ONNX model</span>
-model_path = os.path.join(<span class="hljs-string">'/'</span>, <span class="hljs-string">'tmp'</span>, <span class="hljs-string">'mnist.onnx'</span>)
-onnx.save(onnx_model, model_path)
-print(<span class="hljs-string">'The model is saved.'</span>)
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="inference"></a><a href="#inference" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference</h3>
-<p>After you export the onnx model, you can find a file called <strong>mnist.onnx</strong> in
-the '/tmp' directory, this model, therefore, can be imported by other libraries.
-Now, if you want to import this onnx model into singa again and do the inference
-using the validation dataset, you can define a class called <strong>Infer</strong>, the
-forward function of Infer will be called by the test function to do inference
-for validation dataset. By the way, you should set the label of training to
-<strong>False</strong> to fix the gradient of autograd operators.</p>
-<p>When import the onnx model, you need to call <strong>onnx.load</strong> to load the onnx
-model firstly. Then the onnx model will be fed into the <strong>soonx.prepare</strong> to
-parse and initiate to a singa model(<strong>sg_ir</strong> in the code). The sg_ir contains a
-singa graph within it, and then you can run an step of inference by feeding
-input to its run function.</p>
-<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span>
-        self.sg_ir = sg_ir
-        <span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items():
-            <span class="hljs-comment"># allow the tensors to be updated</span>
-            tens.requires_grad = <span class="hljs-literal">True</span>
-            tens.stores_grad= <span class="hljs-literal">True</span>
-            sg_ir.tensor_map[idx] = tens
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>] <span class="hljs-comment"># we can run one step of inference by feeding input</span>
-
-<span class="hljs-comment"># load the ONNX model</span>
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev) <span class="hljs-comment"># parse and initiate to a singa model</span>
-
-<span class="hljs-comment"># inference</span>
-autograd.training = <span class="hljs-literal">False</span>
-print(<span class="hljs-string">'The inference result is:'</span>)
-test(Infer(sg_ir), valid_x, valid_y, dev=dev)
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="re-training"></a><a href="#re-training" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Re-training</h3>
-<p>Assume after import the model, you want to re-train the model again, we can
-define a function called <strong>re_train</strong>. Before we call this re_train function, we
-should set the label of training to <strong>True</strong> to make the autograde operators
-update their gradient. And after we finish the training, we set it as <strong>False</strong>
-again to call the test function doing inference.</p>
-<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">re_train</span><span class="hljs-params">(sg_ir,
-             x,
-             y,
-             epochs=<span class="hljs-number">1</span>,
-             batch_size=<span class="hljs-number">64</span>,
-             dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    new_model = Infer(sg_ir)
-
-    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
-        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-            output_batch = new_model.forward(x_batch)
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=<span class="hljs-number">0.01</span>)
-            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
-                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
-                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
-    print(<span class="hljs-string">"re-training completed"</span>)
-    <span class="hljs-keyword">return</span> new_model
-
-<span class="hljs-comment"># load the ONNX model</span>
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev)
-
-<span class="hljs-comment"># re-training</span>
-autograd.training = <span class="hljs-literal">True</span>
-new_model = re_train(sg_ir, train_x, train_y, dev=dev)
-autograd.training = <span class="hljs-literal">False</span>
-test(new_model, valid_x, valid_y, dev=dev)
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="transfer-learning"></a><a href="#transfer-learning" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Transfer learning</h3>
-<p>Finally, if we want to do transfer-learning, we can define a function called
-<strong>Trans</strong> to append some layers after the onnx model. For demonstration, the
-code only appends several linear(fully connection) and relu after the onnx
-model. You can define a transfer_learning function to handle the training
-process of the transfer-learning model. And the label of training is the same as
-the previous one.</p>
-<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Trans</span>:</span>
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir, last_layers)</span>:</span>
-        self.sg_ir = sg_ir
-        self.last_layers = last_layers
-        self.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=<span class="hljs-literal">False</span>)
-        self.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=<span class="hljs-literal">False</span>)
-        self.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        y = sg_ir.run([x], last_layers=self.last_layers)[<span class="hljs-number">0</span>]
-        y = self.append_linear1(y)
-        y = autograd.relu(y)
-        y = self.append_linear2(y)
-        y = autograd.relu(y)
-        y = self.append_linear3(y)
-        y = autograd.relu(y)
-        <span class="hljs-keyword">return</span> y
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">transfer_learning</span><span class="hljs-params">(sg_ir,
-             x,
-             y,
-             epochs=<span class="hljs-number">1</span>,
-             batch_size=<span class="hljs-number">64</span>,
-             dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    trans_model = Trans(sg_ir, <span class="hljs-number">-1</span>)
-
-    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
-        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-            output_batch = trans_model.forward(x_batch)
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=<span class="hljs-number">0.07</span>)
-            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
-                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
-                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
-    print(<span class="hljs-string">"transfer-learning completed"</span>)
-    <span class="hljs-keyword">return</span> trans_mode
-
-<span class="hljs-comment"># load the ONNX model</span>
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev)
-
-<span class="hljs-comment"># transfer-learning</span>
-autograd.training = <span class="hljs-literal">True</span>
-new_model = transfer_learning(sg_ir, train_x, train_y, dev=dev)
-autograd.training = <span class="hljs-literal">False</span>
-test(new_model, valid_x, valid_y, dev=dev)
+sgd = opt.SGD(lr=<span class="hljs-number">0</span>.<span class="hljs-number">005</span>, momentum=<span class="hljs-number">0</span>.<span class="hljs-number">9</span>, weight_decay=<span class="hljs-number">1</span>e-<span class="hljs-number">5</span>)
+model.set_optimizer(sgd)
+model.compile([tx], is_train=True, use_graph=graph, sequential=True)
 </code></pre>
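<p>The <code>dist_option</code> branches in <code>train_one_batch</code> can also be written as a lookup table. The following is a minimal sketch, not SINGA code: <code>apply_update</code> and <code>RecordingOptimizer</code> are hypothetical names, and the stand-in optimizer only records which of the update methods shown above was requested. Unlike the <code>if</code>/<code>elif</code> chain, the sketch raises an error for an unrecognized option instead of silently skipping the update.</p>

```python
class RecordingOptimizer:
    """Hypothetical stand-in: records the requested update instead of running it."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Only reached for attributes that are not set, i.e. the update methods.
        def record(loss, **kwargs):
            self.calls.append((name, kwargs))
        return record


def apply_update(optimizer, loss, dist_option='fp32', spars=None):
    """Dispatch to the optimizer method that matches dist_option."""
    updates = {
        'fp32': lambda: optimizer.backward_and_update(loss),
        'fp16': lambda: optimizer.backward_and_update_half(loss),
        'partialUpdate': lambda: optimizer.backward_and_partial_update(loss),
        'sparseTopK': lambda: optimizer.backward_and_sparse_update(
            loss, topK=True, spars=spars),
        'sparseThreshold': lambda: optimizer.backward_and_sparse_update(
            loss, topK=False, spars=spars),
    }
    if dist_option not in updates:
        raise ValueError('unknown dist_option: %r' % (dist_option,))
    updates[dist_option]()


optimizer = RecordingOptimizer()
apply_update(optimizer, loss=0.5, dist_option='sparseTopK', spars=0.05)
print(optimizer.calls)
```

<p>With a real SINGA optimizer in place of the stand-in, the same table would call the actual update methods; whether failing loudly on an unknown option is preferable depends on the training script.</p>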
 <h2><a class="anchor" aria-hidden="true" id="onnx-model-zoo"></a><a href="#onnx-model-zoo" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>ONNX model zoo</h2>
 <p>The <a href="https://github.com/onnx/models">ONNX Model Zoo</a> is a collection of
@@ -567,6 +237,8 @@
 <tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/shufflenet">ShuffleNet_V2</a></b></td><td><a href="https://arxiv.org/pdf/1707.01083.pdf">Simonyan et al.</a></td><td>Extremely computation efficient CNN model that is designed specifically for mobile devices. This network architecture design considers direct metric such as speed, instead of indirect metric like FLOP. Top-1 error from paper - ~30.6%</td><td>[<img src="https://colab.research.google.com/drive/19HfRu3YHP_H2z3BcZujVFRp23_J5XsuA?usp=sharing" alt="Open In Colab"></td></tr>
 </tbody>
 </table>
+<p>We also provide re-training examples that use VGG and ResNet; see
+<code>examples/onnx/training</code>.</p>
 <h3><a class="anchor" aria-hidden="true" id="object-detection"></a><a href="#object-detection" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Object Detection</h3>
 <p>Object detection models detect the presence of multiple objects in an image and
 segment out areas of the image where the objects are detected.</p>
@@ -698,11 +370,13 @@
 <li><p>Empty tensor Empty tensor is illegal in SINGA.</p></li>
 </ul>
 <h2><a class="anchor" aria-hidden="true" id="implementation"></a><a href="#implementation" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Implementation</h2>
-<p>The code of SINGA ONNX locates at <code>python/singa/soonx.py</code>. There are three main
-class, <code>SingaFrontend</code> and <code>SingaBackend</code> and <code>SingaRep</code>. <code>SingaFrontend</code>
-translates a SINGA model to ONNX model; <code>SingaBackend</code> translates a ONNX model
-to <code>SingaRep</code> object which stores all SINGA operators and tensors(the tensor in
-this doc means SINGA <code>Tensor</code>); <code>SingaRep</code> can be run like a SINGA model.</p>
+<p>The SINGA ONNX code is located at <code>python/singa/sonnx.py</code>. There are four main
+classes: <code>SingaFrontend</code>, <code>SingaBackend</code>, <code>SingaRep</code> and <code>SONNXModel</code>.
+<code>SingaFrontend</code> translates a SINGA model to an ONNX model; <code>SingaBackend</code>
+translates an ONNX model to a <code>SingaRep</code> object, which stores all SINGA operators
+and tensors (the tensor in this doc means SINGA <code>Tensor</code>); <code>SingaRep</code> can be run
+like a SINGA model. <code>SONNXModel</code> inherits from <code>model.Model</code>, which defines a
+unified API for SINGA.</p>
 <h3><a class="anchor" aria-hidden="true" id="singafrontend"></a><a href="#singafrontend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaFrontend</h3>
 <p>The entry function of <code>SingaFrontend</code> is <code>singa_to_onnx_model</code> which also is
 called <code>to_onnx</code>. <code>singa_to_onnx_model</code> creates the ONNX model, and it also
@@ -724,40 +398,32 @@
 operators to be changed as bool type.</p>
 <h3><a class="anchor" aria-hidden="true" id="singabackend"></a><a href="#singabackend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaBackend</h3>
 <p>The entry function of <code>SingaBackend</code> is <code>prepare</code> which checks the version of
-ONNX model and call <code>_onnx_model_to_singa_net</code> then.</p>
-<p>The purpose of <code>_onnx_model_to_singa_net</code> is to get SINGA tensors and operators.
+the ONNX model and then calls <code>_onnx_model_to_singa_ops</code>.</p>
+<p>The purpose of <code>_onnx_model_to_singa_ops</code> is to get SINGA tensors and operators.
 The tensors are stored in a dictionary by their name in ONNX, and operators are
-stored in queue by the form of
-<code>namedtuple('SingaOps', ['name', 'op', 'handle', 'forward'])</code>. For each
-operator, <code>name</code> is its ONNX node name; <code>op</code> is the ONNX node; <code>forward</code> is the
-SINGA operator's forward function; <code>handle</code> is prepared for some special
-operators such as Conv and Pooling which has <code>handle</code> object.</p>
-<p>The first step of <code>_onnx_model_to_singa_net</code> is to call <code>_init_graph_parameter</code>
-to get all tensors within the model. For trainable weights, it can init SINGA
-<code>Tensor</code> from <code>onnx_model.graph.initializer</code>. Please note, the weights may also
-be stored within graph's input or a ONNX node called <code>Constant</code>, SINGA can also
-handle these.</p>
-<p>Though all weights are stored within ONNX model, the input of the model is
-unknown but its shape and type. So SINGA support two ways to init input, 1,
-generate random tensor by its shape and type, 2, allow the user to assign the
-input. The first way works fine for most models, however, for some model such as
-bert, the indices of matrix cannot be random generated otherwise it will incurs
-errors.</p>
-<p>Then, <code>_onnx_model_to_singa_net</code> iterators all nodes within ONNX graph to
-translate it to SIGNA operators. Also, <code>_rename_operators</code> defines the operators
-name mapping between SINGA and ONNX. <code>_special_operators</code> defines which function
-to be used to translate the operator. <code>_run_node</code> runs the generated SINGA model
-by its input tensors and store its output tensors for being used by later
-operators.</p>
-<p>This class finally return a <code>SingaRep</code> object and stores all SINGA tensors and
-operators within it.</p>
+stored in a queue as <code>namedtuple('SingaOps', ['node', 'operator'])</code>.
+For each operator, <code>node</code> is an <code>OnnxNode</code> instance, which stores
+some basic information about an ONNX node, and <code>operator</code> is the SINGA operator's
+forward function.</p>
+<p><code>_onnx_model_to_singa_ops</code> works in several steps. First, it calls
+<code>_parse_graph_params</code> to get all tensors, which are stored as <code>params</code>. Then it calls
+<code>_parse_graph_inputs_outputs</code> to get all input and output information, stored as
+<code>inputs</code> and <code>outputs</code>. Finally, it iterates over all nodes within the ONNX graph
+and parses each one with <code>_onnx_node_to_singa_op</code> into SINGA operators or layers, storing
+them as <code>outputs</code>. Some weights are stored within an ONNX node called
+<code>Constant</code>; SONNX handles them with <code>_onnx_constant_to_np</code> and stores them in
+<code>params</code>.</p>
+<p>This class then returns a <code>SingaRep</code> object that stores the <code>params</code>,
+<code>inputs</code>, <code>outputs</code> and <code>layers</code> described above.</p>
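<p>The operator queue described above can be pictured with a small standalone sketch. It reuses only the <code>SingaOps</code> field names from the text; <code>run_ops</code>, the toy operators, and the slice-style meaning given to <code>last_layers</code> are assumptions made for illustration and may differ from what SINGA actually does.</p>

```python
from collections import namedtuple

# Same field names as the namedtuple described in the text.
SingaOps = namedtuple('SingaOps', ['node', 'operator'])


def run_ops(ops, x, last_layers=None):
    """Apply the queued operators in order, optionally stopping early.

    Assumption for this sketch: last_layers behaves like a Python slice
    bound, so -1 skips the final entry of the queue.
    """
    for entry in ops[:last_layers]:
        x = entry.operator(x)
    return x


# Hypothetical stand-ins for parsed ONNX nodes and SINGA forward functions.
ops = [
    SingaOps(node='scale', operator=lambda v: v * 2),
    SingaOps(node='shift', operator=lambda v: v + 1),
    SingaOps(node='square', operator=lambda v: v * v),
]

full = run_ops(ops, 3)           # runs scale, shift, square
truncated = run_ops(ops, 3, -1)  # stops before 'square'
print(full, truncated)
```

<p>Keeping each node next to its forward function lets the same queue serve both a full run and a truncated run, which is the pattern a transfer-learning model relies on when it cuts off the tail of an imported network.</p>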
 <h3><a class="anchor" aria-hidden="true" id="singarep"></a><a href="#singarep" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaRep</h3>
 <p><code>SingaBackend</code> stores all SINGA tensors and operators. <code>run</code> accepts the input
-of the model and run the SINGA operators one by one following the operators
-queue. The user can use <code>last_layers</code> to decide to run the model till the last
-few layers. Set <code>all_outputs</code> as <code>False</code> to get only the final output, <code>True</code> to
-also get all the intermediate output.</p>
-</span></div></article></div><div class="docLastUpdate"><em>Last updated on 20/09/2020</em></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/next/graph"><span class="arrow-prev">← </span><span>Model</span></a><a class="docs-next button" href="/docs/next/dist-train"><span>Distributed Training</span><span class="arrow-next"> →</span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#general-usage">General usage</a><ul class="toc-headings"><li><a href="#loading-an-onnx-model-into-singa">Loading an ONNX Model into SINGA</a></li><li><a href="#inference-singa-model">Inference SINGA model</a></li><li><a href="#saving-singa-model-into-onnx-format">Saving SINGA model into ONNX Format</a></li><li><a href="#re-training-an-onnx-model">Re-training an ONNX model</a></li><li><a href="#transfer-learning-an-onnx-model">Transfer-learning an ONNX model</a></li></ul></li><li><a href="#a-full-example">A Full Example</a><ul class="toc-headings"><li><a href="#load-dataset">Load dataset</a></li><li><a href="#mnist-model">MNIST model</a></li><li><a href="#train-mnist-model-and-export-it-to-onnx">Train mnist model and export it to onnx</a></li><li><a href="#inference">Inference</a></li><li><a href="#re-training">Re-training</a></li><li><a href="#transfer-learning">Transfer learning</a></li></ul></li><li><a href="#onnx-model-zoo">ONNX model zoo</a><ul class="toc-headings"><li><a href="#image-classification">Image Classification</a></li><li><a href="#object-detection">Object Detection</a></li><li><a href="#face-analysis">Face Analysis</a></li><li><a href="#machine-comprehension">Machine Comprehension</a></li></ul></li><li><a href="#supported-operators">Supported operators</a><ul class="toc-headings"><li><a href="#special-comments-for-onnx-backend">Special comments for ONNX backend</a></li></ul></li><li><a href="#implementation">Implementation</a><ul class="toc-headings"><li><a href="#singafrontend">SingaFrontend</a></li><li><a 
href="#singabackend">SingaBackend</a></li><li><a href="#singarep">SingaRep</a></li></ul></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a href="/en/https://apache-singa.readthedocs.io/en/latest/">API Reference</a><a href="/docs/examples">Examples</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/blog">SINGA News</a><a href="https://github.com/apache/singa">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img 
src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2020
+of the model and runs the SINGA operators one by one following the operators'
+queue. The user can use <code>last_layers</code> to truncate the model before its last few
+layers.</p>
+<h3><a class="anchor" aria-hidden="true" id="sonnxmodel"></a><a href="#sonnxmodel" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SONNXModel</h3>
+<p><code>SONNXModel</code> inherits from <code>model.Model</code> and implements the method
+<code>forward</code> to provide a unified API with other SINGA models.</p>
+</span></div></article></div><div class="docLastUpdate"><em>Last updated on 25/11/2020</em></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/next/graph"><span class="arrow-prev">← </span><span>Model</span></a><a class="docs-next button" href="/docs/next/dist-train"><span>Distributed Training</span><span class="arrow-next"> →</span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#general-usage">General usage</a><ul class="toc-headings"><li><a href="#loading-an-onnx-model-into-singa">Loading an ONNX Model into SINGA</a></li><li><a href="#inference-singa-model">Inference SINGA model</a></li><li><a href="#saving-singa-model-into-onnx-format">Saving SINGA model into ONNX Format</a></li><li><a href="#re-training-an-onnx-model">Re-training an ONNX model</a></li><li><a href="#transfer-learning-an-onnx-model">Transfer-learning an ONNX model</a></li></ul></li><li><a href="#onnx-model-zoo">ONNX model zoo</a><ul class="toc-headings"><li><a href="#image-classification">Image Classification</a></li><li><a href="#object-detection">Object Detection</a></li><li><a href="#face-analysis">Face Analysis</a></li><li><a href="#machine-comprehension">Machine Comprehension</a></li></ul></li><li><a href="#supported-operators">Supported operators</a><ul class="toc-headings"><li><a href="#special-comments-for-onnx-backend">Special comments for ONNX backend</a></li></ul></li><li><a href="#implementation">Implementation</a><ul class="toc-headings"><li><a href="#singafrontend">SingaFrontend</a></li><li><a href="#singabackend">SingaBackend</a></li><li><a href="#singarep">SingaRep</a></li><li><a href="#sonnxmodel">SONNXModel</a></li></ul></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a 
href="/en/https://apache-singa.readthedocs.io/en/latest/">API Reference</a><a href="/docs/examples">Examples</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/blog">SINGA News</a><a href="https://github.com/apache/singa">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2020
    The Apache Software Foundation. All rights reserved.
    Apache SINGA, Apache, the Apache feather logo, and
    the Apache SINGA project logos are trademarks of The
diff --git a/content/docs/onnx.html b/content/docs/onnx.html
index eeefc5a..5b7ba63 100644
--- a/content/docs/onnx.html
+++ b/content/docs/onnx.html
@@ -78,60 +78,47 @@
 </table>
 <h2><a class="anchor" aria-hidden="true" id="general-usage"></a><a href="#general-usage" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>General usage</h2>
 <h3><a class="anchor" aria-hidden="true" id="loading-an-onnx-model-into-singa"></a><a href="#loading-an-onnx-model-into-singa" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Loading an ONNX Model into SINGA</h3>
-<p>After loading an ONNX model from disk by <code>onnx.load</code>, you need to update the
-model's batchsize, since for most models, they use a placeholder to represent
-its batchsize. We give an example here, as <code>update_batch_size</code>. You only need to
-update the batchsize of input and output, the shape of internal tensors will be
-inferred automatically.</p>
-<p>Then, you can prepare the SINGA model by using <code>sonnx.prepare</code>. This function
-iterates and translates all the nodes within the ONNX model's graph into SINGA
-operators, loads all stored weights and infers each intermediate tensor's shape.</p>
-<pre><code class="hljs css language-python3"><span class="hljs-built_in">import</span> onnx
-from singa <span class="hljs-built_in">import</span> device
-from singa <span class="hljs-built_in">import</span> sonnx
+<p>After loading an ONNX model from disk with <code>onnx.load</code>, you only need to set
+the batch size of the input using <code>tensor.PlaceHolder</code> (since SINGA v3.0); the
+shapes of internal tensors are inferred automatically.</p>
+<p>Then, define a class that inherits from <code>sonnx.SONNXModel</code> and implement two
+methods: <code>forward</code> for the forward pass and <code>train_one_batch</code> for training.
+When you call <code>model.compile</code>, SONNX iterates over and translates all the nodes
+within the ONNX model's graph into SINGA operators, loads all stored weights and
+infers each intermediate tensor's shape.</p>
+<pre><code class="hljs css language-python3"><span class="hljs-keyword">import</span> onnx
+<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> device
+<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> sonnx, tensor
 
-<span class="hljs-comment"># if the input has multiple tensors? can put this function inside prepare()?</span>
-def update_batch_size(onnx_model, batch_size):
-    <span class="hljs-attr">model_input</span> = onnx_model.graph.input[<span class="hljs-number">0</span>]
-    model_input.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size
-    <span class="hljs-attr">model_output</span> = onnx_model.graph.output[<span class="hljs-number">0</span>]
-    model_output.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size
-    return onnx_model
+<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyModel</span><span class="hljs-params">(sonnx.SONNXModel)</span>:</span>
 
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, onnx_model)</span>:</span>
+        super(MyModel, self).__init__(onnx_model)
 
-<span class="hljs-attr">model_path</span> = <span class="hljs-string">"PATH/To/ONNX/MODEL"</span>
-<span class="hljs-attr">onnx_model</span> = onnx.load(model_path)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, *x)</span>:</span>
+        y = super(MyModel, self).forward(*x)
+        <span class="hljs-comment"># The SINGA model returns its outputs as a list,</span>
+        <span class="hljs-comment"># so when there is only one output,</span>
+        <span class="hljs-comment"># take the first element.</span>
+        <span class="hljs-keyword">return</span> y[<span class="hljs-number">0</span>]
 
-<span class="hljs-comment"># set batch size</span>
-<span class="hljs-attr">onnx_model</span> = update_batch_size(onnx_model, <span class="hljs-number">1</span>)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(self, x, y)</span>:</span>
+        <span class="hljs-keyword">pass</span>
 
-<span class="hljs-comment"># convert onnx graph nodes into SINGA operators</span>
-<span class="hljs-attr">dev</span> = device.create_cuda_gpu()
-<span class="hljs-attr">sg_ir</span> = sonnx.prepare(onnx_model, <span class="hljs-attr">device=dev)</span>
+model_path = <span class="hljs-string">"PATH/To/ONNX/MODEL"</span>
+onnx_model = onnx.load(model_path)
+
+<span class="hljs-comment"># convert onnx model into SINGA model</span>
+dev = device.create_cuda_gpu()
+x = tensor.PlaceHolder(INPUT.shape, device=dev)
+model = MyModel(onnx_model)
+model.compile([x], is_train=<span class="hljs-literal">False</span>, use_graph=<span class="hljs-literal">True</span>, sequential=<span class="hljs-literal">True</span>)
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="inference-singa-model"></a><a href="#inference-singa-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference SINGA model</h3>
-<p>Once the model is created, you can do inference by calling <code>sg_ir.run</code>. The
-input and output must be SINGA <code>Tensor</code> instances. Since SINGA model returns the
-output as a list, if there is only one output, you just need to take the first
-element from the output.</p>
-<pre><code class="hljs css language-python3"><span class="hljs-comment"># can warp the following code in prepare()</span>
-<span class="hljs-comment"># and provide a flag training=True/False?</span>
-
-<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
-
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, sg_ir)</span></span>:
-        <span class="hljs-keyword">self</span>.sg_ir = sg_ir
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x)</span></span>:
-        <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>]
-
-
-data = get_dataset()
-x = tensor.Tensor(device=dev, data=data)
-
-model = Infer(sg_ir)
-y = model.forward(x)
+<p>Once the model is created, you can run inference by calling <code>model.forward</code>. The
+input and output must be SINGA <code>Tensor</code> instances.</p>
+<pre><code class="hljs css language-python3"><span class="hljs-attr">x</span> = tensor.Tensor(device=dev, data=INPUT)
+<span class="hljs-attr">y</span> = model.forward(x)
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="saving-singa-model-into-onnx-format"></a><a href="#saving-singa-model-into-onnx-format" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Saving SINGA model into ONNX Format</h3>
 <p>Given the input tensors and the output tensors generated by the operators the
@@ -142,410 +129,93 @@
 sonnx.to_onnx([<span class="hljs-symbol">x</span>], [<span class="hljs-symbol">y</span>])
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="re-training-an-onnx-model"></a><a href="#re-training-an-onnx-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Re-training an ONNX model</h3>
-<p>To train (or refine) an ONNX model using SINGA, you need to set the internal
-tensors to be trainable</p>
-<pre><code class="hljs css language-python3"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
+<p>To train (or refine) an ONNX model using SINGA, you need to implement
+<code>train_one_batch</code> in your subclass of <code>sonnx.SONNXModel</code> and pass <code>is_train=True</code> when
+calling <code>model.compile</code>.</p>
+<pre><code class="hljs css language-python3">from singa import opt
+from singa import autograd
 
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span>
-        self.sg_ir = sg_ir
-        <span class="hljs-comment">## can wrap these codes in sonnx?</span>
-        <span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items():
-            <span class="hljs-comment"># allow the tensors to be updated</span>
-            tens.requires_grad = <span class="hljs-literal">True</span>
-            tens.stores_grad = <span class="hljs-literal">True</span>
+<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyModel</span>(<span class="hljs-title">sonnx</span>.<span class="hljs-title">SONNXModel</span>):</span>
 
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>]
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, onnx_model)</span></span>:
+        <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).__init_<span class="hljs-number">_</span>(onnx_model)
 
-autograd.training = <span class="hljs-literal">False</span>
-model = Infer(sg_ir)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, *x)</span></span>:
+        y = <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).forward(*x)
+        <span class="hljs-keyword">return</span> y[<span class="hljs-number">0</span>]
 
-autograd.training = <span class="hljs-literal">True</span>
-<span class="hljs-comment"># then you training the model like normal</span>
-<span class="hljs-comment"># give more details??</span>
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x, y, dist_option, spars)</span></span>:
+        out = <span class="hljs-keyword">self</span>.forward(x)
+        loss = autograd.softmax_cross_entropy(out, y)
+        <span class="hljs-keyword">if</span> dist_option == <span class="hljs-string">'fp32'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update(loss)
+        elif dist_option == <span class="hljs-string">'fp16'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update_half(loss)
+        elif dist_option == <span class="hljs-string">'partialUpdate'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_partial_update(loss)
+        elif dist_option == <span class="hljs-string">'sparseTopK'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == <span class="hljs-string">'sparseThreshold'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        <span class="hljs-keyword">return</span> out, loss
+
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">set_optimizer</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, optimizer)</span></span>:
+        <span class="hljs-keyword">self</span>.optimizer = optimizer
+
+sgd = opt.SGD(lr=<span class="hljs-number">0</span>.<span class="hljs-number">005</span>, momentum=<span class="hljs-number">0</span>.<span class="hljs-number">9</span>, weight_decay=<span class="hljs-number">1</span>e-<span class="hljs-number">5</span>)
+model.set_optimizer(sgd)
+model.compile([tx], is_train=True, use_graph=True, sequential=True)  # tx: an input tensor on the target device
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="transfer-learning-an-onnx-model"></a><a href="#transfer-learning-an-onnx-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Transfer-learning an ONNX model</h3>
-<p>You also can append some layers to the end of ONNX model to do
-transfer-learning. The <code>last_layers</code> means you cut the ONNX layers from [0,
-last_layers]. Then you can append more layers by the normal SINGA model.</p>
-<pre><code class="hljs css language-python3">class Trans:
+<p>You can also append layers to the end of the ONNX model for transfer
+learning. The <code>last_layers</code> argument accepts a negative integer indicating the
+layer at which to cut the model. For example, <code>-1</code> cuts after the final output (no
+layer is removed), and <code>-2</code> cuts after the second-to-last layer.</p>
+<pre><code class="hljs css language-python3">from singa import opt
+from singa import autograd, layer
 
-    def __init__(<span class="hljs-literal">self</span>, sg_ir, last_layers):
-        <span class="hljs-literal">self</span>.sg_ir = sg_ir
-        <span class="hljs-literal">self</span>.last_layers = last_layers
-        <span class="hljs-literal">self</span>.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=False)
-        <span class="hljs-literal">self</span>.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=False)
-        <span class="hljs-literal">self</span>.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=False)
+<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyModel</span>(<span class="hljs-title">sonnx</span>.<span class="hljs-title">SONNXModel</span>):</span>
 
-    def forward(<span class="hljs-literal">self</span>, <span class="hljs-symbol">x</span>):
-        <span class="hljs-symbol">y</span> = sg_ir.run([<span class="hljs-symbol">x</span>], last_layers=<span class="hljs-literal">self</span>.last_layers)[<span class="hljs-number">0</span>]
-        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear1(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear2(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear3(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
-        <span class="hljs-keyword">return</span> <span class="hljs-symbol">y</span>
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, onnx_model)</span></span>:
+        <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).__init_<span class="hljs-number">_</span>(onnx_model)
+        <span class="hljs-keyword">self</span>.linear = layer.Linear(<span class="hljs-number">1000</span>, <span class="hljs-number">3</span>)
 
-autograd.training = False
-model = Trans(sg_ir, <span class="hljs-number">-1</span>)
-
-# <span class="hljs-keyword">then</span> you training the model like normal
-</code></pre>
-<h2><a class="anchor" aria-hidden="true" id="a-full-example"></a><a href="#a-full-example" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>A Full Example</h2>
-<p>This part introduces the usage of SINGA ONNX by using the mnist example. In this
-section, the examples of how to export, load, inference, re-training, and
-transfer-learning the minist model are displayed. You can try this part
-<a href="https://colab.research.google.com/drive/1-YOfQqqw3HNhS8WpB8xjDQYutRdUdmCq">here</a>.</p>
-<h3><a class="anchor" aria-hidden="true" id="load-dataset"></a><a href="#load-dataset" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Load dataset</h3>
-<p>Firstly, you need to import some necessary libraries and define some auxiliary
-functions for downloading and preprocessing the dataset:</p>
-<pre><code class="hljs css language-python"><span class="hljs-keyword">import</span> os
-<span class="hljs-keyword">import</span> urllib.request
-<span class="hljs-keyword">import</span> gzip
-<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
-<span class="hljs-keyword">import</span> codecs
-
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> device
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> tensor
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> autograd
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> sonnx
-<span class="hljs-keyword">import</span> onnx
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">load_dataset</span><span class="hljs-params">()</span>:</span>
-    train_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'</span>
-    train_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'</span>
-    valid_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'</span>
-    valid_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'</span>
-    train_x = read_image_file(check_exist_or_download(train_x_url)).astype(
-        np.float32)
-    train_y = read_label_file(check_exist_or_download(train_y_url)).astype(
-        np.float32)
-    valid_x = read_image_file(check_exist_or_download(valid_x_url)).astype(
-        np.float32)
-    valid_y = read_label_file(check_exist_or_download(valid_y_url)).astype(
-        np.float32)
-    <span class="hljs-keyword">return</span> train_x, train_y, valid_x, valid_y
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">check_exist_or_download</span><span class="hljs-params">(url)</span>:</span>
-
-    download_dir = <span class="hljs-string">'/tmp/'</span>
-
-    name = url.rsplit(<span class="hljs-string">'/'</span>, <span class="hljs-number">1</span>)[<span class="hljs-number">-1</span>]
-    filename = os.path.join(download_dir, name)
-    <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> os.path.isfile(filename):
-        print(<span class="hljs-string">"Downloading %s"</span> % url)
-        urllib.request.urlretrieve(url, filename)
-    <span class="hljs-keyword">return</span> filename
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_label_file</span><span class="hljs-params">(path)</span>:</span>
-    <span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
-        data = f.read()
-        <span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2049</span>
-        length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">8</span>).reshape(
-            (length))
-        <span class="hljs-keyword">return</span> parsed
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">get_int</span><span class="hljs-params">(b)</span>:</span>
-    <span class="hljs-keyword">return</span> int(codecs.encode(b, <span class="hljs-string">'hex'</span>), <span class="hljs-number">16</span>)
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_image_file</span><span class="hljs-params">(path)</span>:</span>
-    <span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
-        data = f.read()
-        <span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2051</span>
-        length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
-        num_rows = get_int(data[<span class="hljs-number">8</span>:<span class="hljs-number">12</span>])
-        num_cols = get_int(data[<span class="hljs-number">12</span>:<span class="hljs-number">16</span>])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">16</span>).reshape(
-            (length, <span class="hljs-number">1</span>, num_rows, num_cols))
-        <span class="hljs-keyword">return</span> parsed
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">to_categorical</span><span class="hljs-params">(y, num_classes)</span>:</span>
-    y = np.array(y, dtype=<span class="hljs-string">"int"</span>)
-    n = y.shape[<span class="hljs-number">0</span>]
-    categorical = np.zeros((n, num_classes))
-    categorical[np.arange(n), y] = <span class="hljs-number">1</span>
-    categorical = categorical.astype(np.float32)
-    <span class="hljs-keyword">return</span> categorical
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="mnist-model"></a><a href="#mnist-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>MNIST model</h3>
-<p>Then you can define a class called <strong>CNN</strong> to construct the mnist model which
-consists of several convolution, pooling, fully connection and relu layers. You
-can also define a function to calculate the <strong>accuracy</strong> of our result. Finally,
-you can define a <strong>train</strong> and a <strong>test</strong> function to handle the training and
-prediction process.</p>
-<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">CNN</span>:</span>
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self)</span>:</span>
-        self.conv1 = autograd.Conv2d(<span class="hljs-number">1</span>, <span class="hljs-number">20</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
-        self.conv2 = autograd.Conv2d(<span class="hljs-number">20</span>, <span class="hljs-number">50</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
-        self.linear1 = autograd.Linear(<span class="hljs-number">4</span> * <span class="hljs-number">4</span> * <span class="hljs-number">50</span>, <span class="hljs-number">500</span>, bias=<span class="hljs-literal">False</span>)
-        self.linear2 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)
-        self.pooling1 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
-        self.pooling2 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        y = self.conv1(x)
-        y = autograd.relu(y)
-        y = self.pooling1(y)
-        y = self.conv2(y)
-        y = autograd.relu(y)
-        y = self.pooling2(y)
-        y = autograd.flatten(y)
-        y = self.linear1(y)
-        y = autograd.relu(y)
-        y = self.linear2(y)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, *x)</span></span>:
+        <span class="hljs-comment"># run the ONNX model only up to its third-to-last layer,</span>
+        <span class="hljs-comment"># then append a new linear layer</span>
+        y = <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).forward(*x, last_layers=-<span class="hljs-number">3</span>)[<span class="hljs-number">0</span>]
+        y = <span class="hljs-keyword">self</span>.linear(y)
         <span class="hljs-keyword">return</span> y
 
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x, y, dist_option, spars)</span></span>:
+        out = <span class="hljs-keyword">self</span>.forward(x)
+        loss = autograd.softmax_cross_entropy(out, y)
+        <span class="hljs-keyword">if</span> dist_option == <span class="hljs-string">'fp32'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update(loss)
+        elif dist_option == <span class="hljs-string">'fp16'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update_half(loss)
+        elif dist_option == <span class="hljs-string">'partialUpdate'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_partial_update(loss)
+        elif dist_option == <span class="hljs-string">'sparseTopK'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == <span class="hljs-string">'sparseThreshold'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        <span class="hljs-keyword">return</span> out, loss
 
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">accuracy</span><span class="hljs-params">(pred, target)</span>:</span>
-    y = np.argmax(pred, axis=<span class="hljs-number">1</span>)
-    t = np.argmax(target, axis=<span class="hljs-number">1</span>)
-    a = y == t
-    <span class="hljs-keyword">return</span> np.array(a, <span class="hljs-string">"int"</span>).sum() / float(len(t))
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">set_optimizer</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, optimizer)</span></span>:
+        <span class="hljs-keyword">self</span>.optimizer = optimizer
 
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train</span><span class="hljs-params">(model,
-          x,
-          y,
-          epochs=<span class="hljs-number">1</span>,
-          batch_size=<span class="hljs-number">64</span>,
-          dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
-        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-            output_batch = model.forward(x_batch)
-            <span class="hljs-comment"># onnx_model = sonnx.to_onnx([x_batch], [y])</span>
-            <span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=<span class="hljs-number">0.001</span>)
-            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
-                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
-                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
-    print(<span class="hljs-string">"training completed"</span>)
-    <span class="hljs-keyword">return</span> x_batch, output_batch
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">test</span><span class="hljs-params">(model, x, y, batch_size=<span class="hljs-number">64</span>, dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    result = <span class="hljs-number">0</span>
-    <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-        l_idx = b * batch_size
-        r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-        x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-        target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-        output_batch = model.forward(x_batch)
-        result += accuracy(tensor.to_numpy(output_batch),
-                           tensor.to_numpy(target_batch))
-
-    print(<span class="hljs-string">"testing acc %6.2f"</span> % (result / batch_number))
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="train-mnist-model-and-export-it-to-onnx"></a><a href="#train-mnist-model-and-export-it-to-onnx" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Train mnist model and export it to onnx</h3>
-<p>Now, you can train the mnist model and export its onnx model by calling the
-<strong>soonx.to_onnx</strong> function.</p>
-<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">make_onnx</span><span class="hljs-params">(x, y)</span>:</span>
-    <span class="hljs-keyword">return</span> sonnx.to_onnx([x], [y])
-
-<span class="hljs-comment"># create device</span>
-dev = device.create_cuda_gpu()
-<span class="hljs-comment">#dev = device.get_default_device()</span>
-<span class="hljs-comment"># create model</span>
-model = CNN()
-<span class="hljs-comment"># load data</span>
-train_x, train_y, valid_x, valid_y = load_dataset()
-<span class="hljs-comment"># normalization</span>
-train_x = train_x / <span class="hljs-number">255</span>
-valid_x = valid_x / <span class="hljs-number">255</span>
-train_y = to_categorical(train_y, <span class="hljs-number">10</span>)
-valid_y = to_categorical(valid_y, <span class="hljs-number">10</span>)
-<span class="hljs-comment"># do training</span>
-autograd.training = <span class="hljs-literal">True</span>
-x, y = train(model, train_x, train_y, dev=dev)
-onnx_model = make_onnx(x, y)
-<span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>
-
-<span class="hljs-comment"># Save the ONNX model</span>
-model_path = os.path.join(<span class="hljs-string">'/'</span>, <span class="hljs-string">'tmp'</span>, <span class="hljs-string">'mnist.onnx'</span>)
-onnx.save(onnx_model, model_path)
-print(<span class="hljs-string">'The model is saved.'</span>)
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="inference"></a><a href="#inference" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference</h3>
-<p>After you export the onnx model, you can find a file called <strong>mnist.onnx</strong> in
-the '/tmp' directory, this model, therefore, can be imported by other libraries.
-Now, if you want to import this onnx model into singa again and do the inference
-using the validation dataset, you can define a class called <strong>Infer</strong>, the
-forward function of Infer will be called by the test function to do inference
-for validation dataset. By the way, you should set the label of training to
-<strong>False</strong> to fix the gradient of autograd operators.</p>
-<p>When import the onnx model, you need to call <strong>onnx.load</strong> to load the onnx
-model firstly. Then the onnx model will be fed into the <strong>soonx.prepare</strong> to
-parse and initiate to a singa model(<strong>sg_ir</strong> in the code). The sg_ir contains a
-singa graph within it, and then you can run an step of inference by feeding
-input to its run function.</p>
-<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span>
-        self.sg_ir = sg_ir
-        <span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items():
-            <span class="hljs-comment"># allow the tensors to be updated</span>
-            tens.requires_grad = <span class="hljs-literal">True</span>
-            tens.stores_grad= <span class="hljs-literal">True</span>
-            sg_ir.tensor_map[idx] = tens
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>] <span class="hljs-comment"># we can run one step of inference by feeding input</span>
-
-<span class="hljs-comment"># load the ONNX model</span>
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev) <span class="hljs-comment"># parse and initiate to a singa model</span>
-
-<span class="hljs-comment"># inference</span>
-autograd.training = <span class="hljs-literal">False</span>
-print(<span class="hljs-string">'The inference result is:'</span>)
-test(Infer(sg_ir), valid_x, valid_y, dev=dev)
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="re-training"></a><a href="#re-training" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Re-training</h3>
-<p>Assume after import the model, you want to re-train the model again, we can
-define a function called <strong>re_train</strong>. Before we call this re_train function, we
-should set the label of training to <strong>True</strong> to make the autograde operators
-update their gradient. And after we finish the training, we set it as <strong>False</strong>
-again to call the test function doing inference.</p>
-<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">re_train</span><span class="hljs-params">(sg_ir,
-             x,
-             y,
-             epochs=<span class="hljs-number">1</span>,
-             batch_size=<span class="hljs-number">64</span>,
-             dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    new_model = Infer(sg_ir)
-
-    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
-        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-            output_batch = new_model.forward(x_batch)
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=<span class="hljs-number">0.01</span>)
-            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
-                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
-                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
-    print(<span class="hljs-string">"re-training completed"</span>)
-    <span class="hljs-keyword">return</span> new_model
-
-<span class="hljs-comment"># load the ONNX model</span>
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev)
-
-<span class="hljs-comment"># re-training</span>
-autograd.training = <span class="hljs-literal">True</span>
-new_model = re_train(sg_ir, train_x, train_y, dev=dev)
-autograd.training = <span class="hljs-literal">False</span>
-test(new_model, valid_x, valid_y, dev=dev)
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="transfer-learning"></a><a href="#transfer-learning" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Transfer learning</h3>
-<p>Finally, if we want to do transfer-learning, we can define a function called
-<strong>Trans</strong> to append some layers after the onnx model. For demonstration, the
-code only appends several linear(fully connection) and relu after the onnx
-model. You can define a transfer_learning function to handle the training
-process of the transfer-learning model. And the label of training is the same as
-the previous one.</p>
-<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Trans</span>:</span>
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir, last_layers)</span>:</span>
-        self.sg_ir = sg_ir
-        self.last_layers = last_layers
-        self.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=<span class="hljs-literal">False</span>)
-        self.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=<span class="hljs-literal">False</span>)
-        self.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        y = sg_ir.run([x], last_layers=self.last_layers)[<span class="hljs-number">0</span>]
-        y = self.append_linear1(y)
-        y = autograd.relu(y)
-        y = self.append_linear2(y)
-        y = autograd.relu(y)
-        y = self.append_linear3(y)
-        y = autograd.relu(y)
-        <span class="hljs-keyword">return</span> y
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">transfer_learning</span><span class="hljs-params">(sg_ir,
-             x,
-             y,
-             epochs=<span class="hljs-number">1</span>,
-             batch_size=<span class="hljs-number">64</span>,
-             dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    trans_model = Trans(sg_ir, <span class="hljs-number">-1</span>)
-
-    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
-        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-            output_batch = trans_model.forward(x_batch)
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=<span class="hljs-number">0.07</span>)
-            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
-                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
-                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
-    print(<span class="hljs-string">"transfer-learning completed"</span>)
-    <span class="hljs-keyword">return</span> trans_mode
-
-<span class="hljs-comment"># load the ONNX model</span>
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev)
-
-<span class="hljs-comment"># transfer-learning</span>
-autograd.training = <span class="hljs-literal">True</span>
-new_model = transfer_learning(sg_ir, train_x, train_y, dev=dev)
-autograd.training = <span class="hljs-literal">False</span>
-test(new_model, valid_x, valid_y, dev=dev)
+sgd = opt.SGD(lr=<span class="hljs-number">0</span>.<span class="hljs-number">005</span>, momentum=<span class="hljs-number">0</span>.<span class="hljs-number">9</span>, weight_decay=<span class="hljs-number">1</span>e-<span class="hljs-number">5</span>)
+model.set_optimizer(sgd)
+model.compile([tx], is_train=True, use_graph=graph, sequential=True)
 </code></pre>
 <h2><a class="anchor" aria-hidden="true" id="onnx-model-zoo"></a><a href="#onnx-model-zoo" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>ONNX model zoo</h2>
 <p>The <a href="https://github.com/onnx/models">ONNX Model Zoo</a> is a collection of
@@ -567,6 +237,8 @@
 <tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/shufflenet">ShuffleNet_V2</a></b></td><td><a href="https://arxiv.org/pdf/1707.01083.pdf">Simonyan et al.</a></td><td>Extremely computation efficient CNN model that is designed specifically for mobile devices. This network architecture design considers direct metric such as speed, instead of indirect metric like FLOP. Top-1 error from paper - ~30.6%</td><td>[<img src="https://colab.research.google.com/drive/19HfRu3YHP_H2z3BcZujVFRp23_J5XsuA?usp=sharing" alt="Open In Colab"></td></tr>
 </tbody>
 </table>
+<p>We also provide re-training examples that use VGG and ResNet; see
+<code>examples/onnx/training</code>.</p>
 <h3><a class="anchor" aria-hidden="true" id="object-detection"></a><a href="#object-detection" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Object Detection</h3>
 <p>Object detection models detect the presence of multiple objects in an image and
 segment out areas of the image where the objects are detected.</p>
@@ -698,11 +370,13 @@
 <li><p>Empty tensor Empty tensor is illegal in SINGA.</p></li>
 </ul>
 <h2><a class="anchor" aria-hidden="true" id="implementation"></a><a href="#implementation" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Implementation</h2>
-<p>The code of SINGA ONNX locates at <code>python/singa/soonx.py</code>. There are three main
-class, <code>SingaFrontend</code> and <code>SingaBackend</code> and <code>SingaRep</code>. <code>SingaFrontend</code>
-translates a SINGA model to ONNX model; <code>SingaBackend</code> translates a ONNX model
-to <code>SingaRep</code> object which stores all SINGA operators and tensors(the tensor in
-this doc means SINGA <code>Tensor</code>); <code>SingaRep</code> can be run like a SINGA model.</p>
+<p>The SINGA ONNX code is located at <code>python/singa/sonnx.py</code>. There are four main
+classes: <code>SingaFrontend</code>, <code>SingaBackend</code>, <code>SingaRep</code>, and <code>SONNXModel</code>.
+<code>SingaFrontend</code> translates a SINGA model to an ONNX model; <code>SingaBackend</code>
+translates an ONNX model to a <code>SingaRep</code> object, which stores all SINGA operators
+and tensors (in this doc, "tensor" means a SINGA <code>Tensor</code>); <code>SingaRep</code> can be run
+like a SINGA model. <code>SONNXModel</code> inherits from <code>model.Model</code>, which defines a
+unified API for SINGA.</p>
 <h3><a class="anchor" aria-hidden="true" id="singafrontend"></a><a href="#singafrontend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaFrontend</h3>
 <p>The entry function of <code>SingaFrontend</code> is <code>singa_to_onnx_model</code> which also is
 called <code>to_onnx</code>. <code>singa_to_onnx_model</code> creates the ONNX model, and it also
@@ -724,40 +398,32 @@
 operators to be changed as bool type.</p>
 <h3><a class="anchor" aria-hidden="true" id="singabackend"></a><a href="#singabackend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaBackend</h3>
 <p>The entry function of <code>SingaBackend</code> is <code>prepare</code> which checks the version of
-ONNX model and call <code>_onnx_model_to_singa_net</code> then.</p>
-<p>The purpose of <code>_onnx_model_to_singa_net</code> is to get SINGA tensors and operators.
+the ONNX model and then calls <code>_onnx_model_to_singa_ops</code>.</p>
+<p>The purpose of <code>_onnx_model_to_singa_ops</code> is to get SINGA tensors and operators.
 The tensors are stored in a dictionary by their name in ONNX, and operators are
-stored in queue by the form of
-<code>namedtuple('SingaOps', ['name', 'op', 'handle', 'forward'])</code>. For each
-operator, <code>name</code> is its ONNX node name; <code>op</code> is the ONNX node; <code>forward</code> is the
-SINGA operator's forward function; <code>handle</code> is prepared for some special
-operators such as Conv and Pooling which has <code>handle</code> object.</p>
-<p>The first step of <code>_onnx_model_to_singa_net</code> is to call <code>_init_graph_parameter</code>
-to get all tensors within the model. For trainable weights, it can init SINGA
-<code>Tensor</code> from <code>onnx_model.graph.initializer</code>. Please note, the weights may also
-be stored within graph's input or a ONNX node called <code>Constant</code>, SINGA can also
-handle these.</p>
-<p>Though all weights are stored within ONNX model, the input of the model is
-unknown but its shape and type. So SINGA support two ways to init input, 1,
-generate random tensor by its shape and type, 2, allow the user to assign the
-input. The first way works fine for most models, however, for some model such as
-bert, the indices of matrix cannot be random generated otherwise it will incurs
-errors.</p>
-<p>Then, <code>_onnx_model_to_singa_net</code> iterators all nodes within ONNX graph to
-translate it to SIGNA operators. Also, <code>_rename_operators</code> defines the operators
-name mapping between SINGA and ONNX. <code>_special_operators</code> defines which function
-to be used to translate the operator. <code>_run_node</code> runs the generated SINGA model
-by its input tensors and store its output tensors for being used by later
-operators.</p>
-<p>This class finally return a <code>SingaRep</code> object and stores all SINGA tensors and
-operators within it.</p>
+stored in a queue as <code>namedtuple('SingaOps', ['node', 'operator'])</code>.
+For each operator, <code>node</code> is an instance of <code>OnnxNode</code>, which stores
+some basic information about an ONNX node, and <code>operator</code> is the SINGA operator's
+forward function.</p>
+<p><code>_onnx_model_to_singa_ops</code> proceeds in several steps. First, it calls
+<code>_parse_graph_params</code> to get all tensors, which are stored as <code>params</code>. Then it calls
+<code>_parse_graph_inputs_outputs</code> to get all input and output information, stored as
+<code>inputs</code> and <code>outputs</code>. Finally, it iterates over all nodes within the ONNX graph,
+parses each one with <code>_onnx_node_to_singa_op</code> into SINGA operators or layers, and stores
+them as <code>layers</code>. Some weights are stored within an ONNX node called
+<code>Constant</code>; SONNX handles them with <code>_onnx_constant_to_np</code> and stores them in
+<code>params</code>.</p>
+<p>This class finally returns a <code>SingaRep</code> object that stores the above <code>params</code>,
+<code>inputs</code>, <code>outputs</code>, and <code>layers</code>.</p>
 <h3><a class="anchor" aria-hidden="true" id="singarep"></a><a href="#singarep" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaRep</h3>
 <p><code>SingaBackend</code> stores all SINGA tensors and operators. <code>run</code> accepts the input
-of the model and run the SINGA operators one by one following the operators
-queue. The user can use <code>last_layers</code> to decide to run the model till the last
-few layers. Set <code>all_outputs</code> as <code>False</code> to get only the final output, <code>True</code> to
-also get all the intermediate output.</p>
-</span></div></article></div><div class="docLastUpdate"><em>Last updated on 20/09/2020</em></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/graph"><span class="arrow-prev">← </span><span>Model</span></a><a class="docs-next button" href="/docs/dist-train"><span>Distributed Training</span><span class="arrow-next"> →</span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#general-usage">General usage</a><ul class="toc-headings"><li><a href="#loading-an-onnx-model-into-singa">Loading an ONNX Model into SINGA</a></li><li><a href="#inference-singa-model">Inference SINGA model</a></li><li><a href="#saving-singa-model-into-onnx-format">Saving SINGA model into ONNX Format</a></li><li><a href="#re-training-an-onnx-model">Re-training an ONNX model</a></li><li><a href="#transfer-learning-an-onnx-model">Transfer-learning an ONNX model</a></li></ul></li><li><a href="#a-full-example">A Full Example</a><ul class="toc-headings"><li><a href="#load-dataset">Load dataset</a></li><li><a href="#mnist-model">MNIST model</a></li><li><a href="#train-mnist-model-and-export-it-to-onnx">Train mnist model and export it to onnx</a></li><li><a href="#inference">Inference</a></li><li><a href="#re-training">Re-training</a></li><li><a href="#transfer-learning">Transfer learning</a></li></ul></li><li><a href="#onnx-model-zoo">ONNX model zoo</a><ul class="toc-headings"><li><a href="#image-classification">Image Classification</a></li><li><a href="#object-detection">Object Detection</a></li><li><a href="#face-analysis">Face Analysis</a></li><li><a href="#machine-comprehension">Machine Comprehension</a></li></ul></li><li><a href="#supported-operators">Supported operators</a><ul class="toc-headings"><li><a href="#special-comments-for-onnx-backend">Special comments for ONNX backend</a></li></ul></li><li><a href="#implementation">Implementation</a><ul class="toc-headings"><li><a href="#singafrontend">SingaFrontend</a></li><li><a 
href="#singabackend">SingaBackend</a></li><li><a href="#singarep">SingaRep</a></li></ul></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a href="/en/https://apache-singa.readthedocs.io/en/latest/">API Reference</a><a href="/docs/examples">Examples</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/blog">SINGA News</a><a href="https://github.com/apache/singa">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img 
src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2020
+of the model and runs the SINGA operators one by one following the operators'
+queue. The user can use <code>last_layers</code> to cut off the model after the last few
+layers.</p>
+<h3><a class="anchor" aria-hidden="true" id="sonnxmodel"></a><a href="#sonnxmodel" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SONNXModel</h3>
+<p><code>SONNXModel</code> inherits from <code>model.Model</code> and implements the method
+<code>forward</code> to provide a unified API with other SINGA models.</p>
+</span></div></article></div><div class="docLastUpdate"><em>Last updated on 25/11/2020</em></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/graph"><span class="arrow-prev">← </span><span>Model</span></a><a class="docs-next button" href="/docs/dist-train"><span>Distributed Training</span><span class="arrow-next"> →</span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#general-usage">General usage</a><ul class="toc-headings"><li><a href="#loading-an-onnx-model-into-singa">Loading an ONNX Model into SINGA</a></li><li><a href="#inference-singa-model">Inference SINGA model</a></li><li><a href="#saving-singa-model-into-onnx-format">Saving SINGA model into ONNX Format</a></li><li><a href="#re-training-an-onnx-model">Re-training an ONNX model</a></li><li><a href="#transfer-learning-an-onnx-model">Transfer-learning an ONNX model</a></li></ul></li><li><a href="#onnx-model-zoo">ONNX model zoo</a><ul class="toc-headings"><li><a href="#image-classification">Image Classification</a></li><li><a href="#object-detection">Object Detection</a></li><li><a href="#face-analysis">Face Analysis</a></li><li><a href="#machine-comprehension">Machine Comprehension</a></li></ul></li><li><a href="#supported-operators">Supported operators</a><ul class="toc-headings"><li><a href="#special-comments-for-onnx-backend">Special comments for ONNX backend</a></li></ul></li><li><a href="#implementation">Implementation</a><ul class="toc-headings"><li><a href="#singafrontend">SingaFrontend</a></li><li><a href="#singabackend">SingaBackend</a></li><li><a href="#singarep">SingaRep</a></li><li><a href="#sonnxmodel">SONNXModel</a></li></ul></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a 
href="/en/https://apache-singa.readthedocs.io/en/latest/">API Reference</a><a href="/docs/examples">Examples</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/blog">SINGA News</a><a href="https://github.com/apache/singa">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2020
    The Apache Software Foundation. All rights reserved.
    Apache SINGA, Apache, the Apache feather logo, and
    the Apache SINGA project logos are trademarks of The
diff --git a/content/docs/onnx/index.html b/content/docs/onnx/index.html
index eeefc5a..5b7ba63 100644
--- a/content/docs/onnx/index.html
+++ b/content/docs/onnx/index.html
@@ -78,60 +78,47 @@
 </table>
 <h2><a class="anchor" aria-hidden="true" id="general-usage"></a><a href="#general-usage" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>General usage</h2>
 <h3><a class="anchor" aria-hidden="true" id="loading-an-onnx-model-into-singa"></a><a href="#loading-an-onnx-model-into-singa" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Loading an ONNX Model into SINGA</h3>
-<p>After loading an ONNX model from disk by <code>onnx.load</code>, you need to update the
-model's batchsize, since for most models, they use a placeholder to represent
-its batchsize. We give an example here, as <code>update_batch_size</code>. You only need to
-update the batchsize of input and output, the shape of internal tensors will be
-inferred automatically.</p>
-<p>Then, you can prepare the SINGA model by using <code>sonnx.prepare</code>. This function
-iterates and translates all the nodes within the ONNX model's graph into SINGA
-operators, loads all stored weights and infers each intermediate tensor's shape.</p>
-<pre><code class="hljs css language-python3"><span class="hljs-built_in">import</span> onnx
-from singa <span class="hljs-built_in">import</span> device
-from singa <span class="hljs-built_in">import</span> sonnx
+<p>After loading an ONNX model from disk with <code>onnx.load</code>, you only need to update
+the batch size of the input using <code>tensor.PlaceHolder</code> (in SINGA releases after v3.0); the shape
+of internal tensors will be inferred automatically.</p>
+<p>Then, you should define a class inheriting from <code>sonnx.SONNXModel</code> and implement
+two methods: <code>forward</code> for the forward pass and <code>train_one_batch</code> for training.
+After you call <code>model.compile</code>, SONNX iterates over and translates all the nodes
+within the ONNX model's graph into SINGA operators, loads all stored weights, and
+infers each intermediate tensor's shape.</p>
+<pre><code class="hljs css language-python3"><span class="hljs-keyword">import</span> onnx
+<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> device, tensor
+<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> sonnx
 
-<span class="hljs-comment"># if the input has multiple tensors? can put this function inside prepare()?</span>
-def update_batch_size(onnx_model, batch_size):
-    <span class="hljs-attr">model_input</span> = onnx_model.graph.input[<span class="hljs-number">0</span>]
-    model_input.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size
-    <span class="hljs-attr">model_output</span> = onnx_model.graph.output[<span class="hljs-number">0</span>]
-    model_output.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size
-    return onnx_model
+<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyModel</span><span class="hljs-params">(sonnx.SONNXModel)</span>:</span>
 
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, onnx_model)</span>:</span>
+        super(MyModel, self).__init__(onnx_model)
 
-<span class="hljs-attr">model_path</span> = <span class="hljs-string">"PATH/To/ONNX/MODEL"</span>
-<span class="hljs-attr">onnx_model</span> = onnx.load(model_path)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, *x)</span>:</span>
+        y = super(MyModel, self).forward(*x)
+        <span class="hljs-comment"># Since SINGA model returns the output as a list,</span>
+        <span class="hljs-comment"># if there is only one output,</span>
+        <span class="hljs-comment"># you just need to take the first element.</span>
+        <span class="hljs-keyword">return</span> y[<span class="hljs-number">0</span>]
 
-<span class="hljs-comment"># set batch size</span>
-<span class="hljs-attr">onnx_model</span> = update_batch_size(onnx_model, <span class="hljs-number">1</span>)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(self, x, y)</span>:</span>
+        <span class="hljs-keyword">pass</span>
 
-<span class="hljs-comment"># convert onnx graph nodes into SINGA operators</span>
-<span class="hljs-attr">dev</span> = device.create_cuda_gpu()
-<span class="hljs-attr">sg_ir</span> = sonnx.prepare(onnx_model, <span class="hljs-attr">device=dev)</span>
+model_path = <span class="hljs-string">"PATH/To/ONNX/MODEL"</span>
+onnx_model = onnx.load(model_path)
+
+<span class="hljs-comment"># convert onnx model into SINGA model</span>
+dev = device.create_cuda_gpu()
+x = tensor.PlaceHolder(INPUT.shape, device=dev)
+model = MyModel(onnx_model)
+model.compile([x], is_train=<span class="hljs-literal">False</span>, use_graph=<span class="hljs-literal">True</span>, sequential=<span class="hljs-literal">True</span>)
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="inference-singa-model"></a><a href="#inference-singa-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference SINGA model</h3>
-<p>Once the model is created, you can do inference by calling <code>sg_ir.run</code>. The
-input and output must be SINGA <code>Tensor</code> instances. Since SINGA model returns the
-output as a list, if there is only one output, you just need to take the first
-element from the output.</p>
-<pre><code class="hljs css language-python3"><span class="hljs-comment"># can warp the following code in prepare()</span>
-<span class="hljs-comment"># and provide a flag training=True/False?</span>
-
-<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
-
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, sg_ir)</span></span>:
-        <span class="hljs-keyword">self</span>.sg_ir = sg_ir
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x)</span></span>:
-        <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>]
-
-
-data = get_dataset()
-x = tensor.Tensor(device=dev, data=data)
-
-model = Infer(sg_ir)
-y = model.forward(x)
+<p>Once the model is created, you can do inference by calling <code>model.forward</code>. The
+input and output must be SINGA <code>Tensor</code> instances.</p>
+<pre><code class="hljs css language-python3"><span class="hljs-attr">x</span> = tensor.Tensor(device=dev, data=INPUT)
+<span class="hljs-attr">y</span> = model.forward(x)
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="saving-singa-model-into-onnx-format"></a><a href="#saving-singa-model-into-onnx-format" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Saving SINGA model into ONNX Format</h3>
 <p>Given the input tensors and the output tensors generated by the operators the
@@ -142,410 +129,93 @@
 sonnx.to_onnx([<span class="hljs-symbol">x</span>], [<span class="hljs-symbol">y</span>])
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="re-training-an-onnx-model"></a><a href="#re-training-an-onnx-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Re-training an ONNX model</h3>
-<p>To train (or refine) an ONNX model using SINGA, you need to set the internal
-tensors to be trainable</p>
-<pre><code class="hljs css language-python3"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
+<p>To train (or refine) an ONNX model using SINGA, you need to implement the
+<code>train_one_batch</code> method of <code>sonnx.SONNXModel</code> and set <code>is_train=True</code> when
+calling <code>model.compile</code>.</p>
+<pre><code class="hljs css language-python3">from singa import opt
+from singa import autograd
 
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span>
-        self.sg_ir = sg_ir
-        <span class="hljs-comment">## can wrap these codes in sonnx?</span>
-        <span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items():
-            <span class="hljs-comment"># allow the tensors to be updated</span>
-            tens.requires_grad = <span class="hljs-literal">True</span>
-            tens.stores_grad = <span class="hljs-literal">True</span>
+<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyModel</span>(<span class="hljs-title">sonnx</span>.<span class="hljs-title">SONNXModel</span>):</span>
 
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>]
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, onnx_model)</span></span>:
+        <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).__init_<span class="hljs-number">_</span>(onnx_model)
 
-autograd.training = <span class="hljs-literal">False</span>
-model = Infer(sg_ir)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, *x)</span></span>:
+        y = <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).forward(*x)
+        <span class="hljs-keyword">return</span> y[<span class="hljs-number">0</span>]
 
-autograd.training = <span class="hljs-literal">True</span>
-<span class="hljs-comment"># then you training the model like normal</span>
-<span class="hljs-comment"># give more details??</span>
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x, y, dist_option, spars)</span></span>:
+        out = <span class="hljs-keyword">self</span>.forward(x)
+        loss = autograd.softmax_cross_entropy(out, y)
+        <span class="hljs-keyword">if</span> dist_option == <span class="hljs-string">'fp32'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update(loss)
+        elif dist_option == <span class="hljs-string">'fp16'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update_half(loss)
+        elif dist_option == <span class="hljs-string">'partialUpdate'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_partial_update(loss)
+        elif dist_option == <span class="hljs-string">'sparseTopK'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == <span class="hljs-string">'sparseThreshold'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        <span class="hljs-keyword">return</span> out, loss
+
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">set_optimizer</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, optimizer)</span></span>:
+        <span class="hljs-keyword">self</span>.optimizer = optimizer
+
+sgd = opt.SGD(lr=<span class="hljs-number">0</span>.<span class="hljs-number">005</span>, momentum=<span class="hljs-number">0</span>.<span class="hljs-number">9</span>, weight_decay=<span class="hljs-number">1</span>e-<span class="hljs-number">5</span>)
+model.set_optimizer(sgd)
+model.compile([tx], is_train=True, use_graph=graph, sequential=True)
 </code></pre>
 <h3><a class="anchor" aria-hidden="true" id="transfer-learning-an-onnx-model"></a><a href="#transfer-learning-an-onnx-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Transfer-learning an ONNX model</h3>
-<p>You also can append some layers to the end of ONNX model to do
-transfer-learning. The <code>last_layers</code> means you cut the ONNX layers from [0,
-last_layers]. Then you can append more layers by the normal SINGA model.</p>
-<pre><code class="hljs css language-python3">class Trans:
+<p>You can also append some layers to the end of the ONNX model to do
+transfer learning. <code>last_layers</code> accepts a negative integer indicating the
+layer to cut off from. For example, <code>-1</code> means cutting off after the final output (no
+layer is cut off), and <code>-2</code> means cutting off after the second-to-last layer.</p>
+<pre><code class="hljs css language-python3">from singa import opt
+from singa import autograd, layer
 
-    def __init__(<span class="hljs-literal">self</span>, sg_ir, last_layers):
-        <span class="hljs-literal">self</span>.sg_ir = sg_ir
-        <span class="hljs-literal">self</span>.last_layers = last_layers
-        <span class="hljs-literal">self</span>.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=False)
-        <span class="hljs-literal">self</span>.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=False)
-        <span class="hljs-literal">self</span>.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=False)
+<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyModel</span>(<span class="hljs-title">sonnx</span>.<span class="hljs-title">SONNXModel</span>):</span>
 
-    def forward(<span class="hljs-literal">self</span>, <span class="hljs-symbol">x</span>):
-        <span class="hljs-symbol">y</span> = sg_ir.run([<span class="hljs-symbol">x</span>], last_layers=<span class="hljs-literal">self</span>.last_layers)[<span class="hljs-number">0</span>]
-        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear1(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear2(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear3(<span class="hljs-symbol">y</span>)
-        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
-        <span class="hljs-keyword">return</span> <span class="hljs-symbol">y</span>
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, onnx_model)</span></span>:
+        <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).__init_<span class="hljs-number">_</span>(onnx_model)
+        <span class="hljs-keyword">self</span>.linear = layer.Linear(<span class="hljs-number">1000</span>, <span class="hljs-number">3</span>)
 
-autograd.training = False
-model = Trans(sg_ir, <span class="hljs-number">-1</span>)
-
-# <span class="hljs-keyword">then</span> you training the model like normal
-</code></pre>
-<h2><a class="anchor" aria-hidden="true" id="a-full-example"></a><a href="#a-full-example" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>A Full Example</h2>
-<p>This part introduces the usage of SINGA ONNX by using the mnist example. In this
-section, the examples of how to export, load, inference, re-training, and
-transfer-learning the minist model are displayed. You can try this part
-<a href="https://colab.research.google.com/drive/1-YOfQqqw3HNhS8WpB8xjDQYutRdUdmCq">here</a>.</p>
-<h3><a class="anchor" aria-hidden="true" id="load-dataset"></a><a href="#load-dataset" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Load dataset</h3>
-<p>Firstly, you need to import some necessary libraries and define some auxiliary
-functions for downloading and preprocessing the dataset:</p>
-<pre><code class="hljs css language-python"><span class="hljs-keyword">import</span> os
-<span class="hljs-keyword">import</span> urllib.request
-<span class="hljs-keyword">import</span> gzip
-<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
-<span class="hljs-keyword">import</span> codecs
-
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> device
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> tensor
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> autograd
-<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> sonnx
-<span class="hljs-keyword">import</span> onnx
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">load_dataset</span><span class="hljs-params">()</span>:</span>
-    train_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'</span>
-    train_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'</span>
-    valid_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'</span>
-    valid_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'</span>
-    train_x = read_image_file(check_exist_or_download(train_x_url)).astype(
-        np.float32)
-    train_y = read_label_file(check_exist_or_download(train_y_url)).astype(
-        np.float32)
-    valid_x = read_image_file(check_exist_or_download(valid_x_url)).astype(
-        np.float32)
-    valid_y = read_label_file(check_exist_or_download(valid_y_url)).astype(
-        np.float32)
-    <span class="hljs-keyword">return</span> train_x, train_y, valid_x, valid_y
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">check_exist_or_download</span><span class="hljs-params">(url)</span>:</span>
-
-    download_dir = <span class="hljs-string">'/tmp/'</span>
-
-    name = url.rsplit(<span class="hljs-string">'/'</span>, <span class="hljs-number">1</span>)[<span class="hljs-number">-1</span>]
-    filename = os.path.join(download_dir, name)
-    <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> os.path.isfile(filename):
-        print(<span class="hljs-string">"Downloading %s"</span> % url)
-        urllib.request.urlretrieve(url, filename)
-    <span class="hljs-keyword">return</span> filename
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_label_file</span><span class="hljs-params">(path)</span>:</span>
-    <span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
-        data = f.read()
-        <span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2049</span>
-        length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">8</span>).reshape(
-            (length))
-        <span class="hljs-keyword">return</span> parsed
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">get_int</span><span class="hljs-params">(b)</span>:</span>
-    <span class="hljs-keyword">return</span> int(codecs.encode(b, <span class="hljs-string">'hex'</span>), <span class="hljs-number">16</span>)
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_image_file</span><span class="hljs-params">(path)</span>:</span>
-    <span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
-        data = f.read()
-        <span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2051</span>
-        length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
-        num_rows = get_int(data[<span class="hljs-number">8</span>:<span class="hljs-number">12</span>])
-        num_cols = get_int(data[<span class="hljs-number">12</span>:<span class="hljs-number">16</span>])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">16</span>).reshape(
-            (length, <span class="hljs-number">1</span>, num_rows, num_cols))
-        <span class="hljs-keyword">return</span> parsed
-
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">to_categorical</span><span class="hljs-params">(y, num_classes)</span>:</span>
-    y = np.array(y, dtype=<span class="hljs-string">"int"</span>)
-    n = y.shape[<span class="hljs-number">0</span>]
-    categorical = np.zeros((n, num_classes))
-    categorical[np.arange(n), y] = <span class="hljs-number">1</span>
-    categorical = categorical.astype(np.float32)
-    <span class="hljs-keyword">return</span> categorical
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="mnist-model"></a><a href="#mnist-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>MNIST model</h3>
-<p>Then you can define a class called <strong>CNN</strong> to construct the mnist model which
-consists of several convolution, pooling, fully connection and relu layers. You
-can also define a function to calculate the <strong>accuracy</strong> of our result. Finally,
-you can define a <strong>train</strong> and a <strong>test</strong> function to handle the training and
-prediction process.</p>
-<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">CNN</span>:</span>
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self)</span>:</span>
-        self.conv1 = autograd.Conv2d(<span class="hljs-number">1</span>, <span class="hljs-number">20</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
-        self.conv2 = autograd.Conv2d(<span class="hljs-number">20</span>, <span class="hljs-number">50</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
-        self.linear1 = autograd.Linear(<span class="hljs-number">4</span> * <span class="hljs-number">4</span> * <span class="hljs-number">50</span>, <span class="hljs-number">500</span>, bias=<span class="hljs-literal">False</span>)
-        self.linear2 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)
-        self.pooling1 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
-        self.pooling2 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        y = self.conv1(x)
-        y = autograd.relu(y)
-        y = self.pooling1(y)
-        y = self.conv2(y)
-        y = autograd.relu(y)
-        y = self.pooling2(y)
-        y = autograd.flatten(y)
-        y = self.linear1(y)
-        y = autograd.relu(y)
-        y = self.linear2(y)
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, *x)</span></span>:
+        <span class="hljs-comment"># cut off after the last third layer</span>
+        <span class="hljs-comment"># and append a linear layer</span>
+        y = <span class="hljs-keyword">super</span>(MyModel, <span class="hljs-keyword">self</span>).forward(*x, last_layers=-<span class="hljs-number">3</span>)[<span class="hljs-number">0</span>]
+        y = <span class="hljs-keyword">self</span>.linear(y)
         <span class="hljs-keyword">return</span> y
 
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x, y, dist_option, spars)</span></span>:
+        out = <span class="hljs-keyword">self</span>.forward(x)
+        loss = autograd.softmax_cross_entropy(out, y)
+        <span class="hljs-keyword">if</span> dist_option == <span class="hljs-string">'fp32'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update(loss)
+        elif dist_option == <span class="hljs-string">'fp16'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_update_half(loss)
+        elif dist_option == <span class="hljs-string">'partialUpdate'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_partial_update(loss)
+        elif dist_option == <span class="hljs-string">'sparseTopK'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == <span class="hljs-string">'sparseThreshold'</span>:
+            <span class="hljs-keyword">self</span>.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        <span class="hljs-keyword">return</span> out, loss
 
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">accuracy</span><span class="hljs-params">(pred, target)</span>:</span>
-    y = np.argmax(pred, axis=<span class="hljs-number">1</span>)
-    t = np.argmax(target, axis=<span class="hljs-number">1</span>)
-    a = y == t
-    <span class="hljs-keyword">return</span> np.array(a, <span class="hljs-string">"int"</span>).sum() / float(len(t))
+    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">set_optimizer</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, optimizer)</span></span>:
+        <span class="hljs-keyword">self</span>.optimizer = optimizer
 
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train</span><span class="hljs-params">(model,
-          x,
-          y,
-          epochs=<span class="hljs-number">1</span>,
-          batch_size=<span class="hljs-number">64</span>,
-          dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
-        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-            output_batch = model.forward(x_batch)
-            <span class="hljs-comment"># onnx_model = sonnx.to_onnx([x_batch], [y])</span>
-            <span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=<span class="hljs-number">0.001</span>)
-            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
-                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
-                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
-    print(<span class="hljs-string">"training completed"</span>)
-    <span class="hljs-keyword">return</span> x_batch, output_batch
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">test</span><span class="hljs-params">(model, x, y, batch_size=<span class="hljs-number">64</span>, dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    result = <span class="hljs-number">0</span>
-    <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-        l_idx = b * batch_size
-        r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-        x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-        target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-        output_batch = model.forward(x_batch)
-        result += accuracy(tensor.to_numpy(output_batch),
-                           tensor.to_numpy(target_batch))
-
-    print(<span class="hljs-string">"testing acc %6.2f"</span> % (result / batch_number))
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="train-mnist-model-and-export-it-to-onnx"></a><a href="#train-mnist-model-and-export-it-to-onnx" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Train mnist model and export it to onnx</h3>
-<p>Now, you can train the mnist model and export its onnx model by calling the
-<strong>soonx.to_onnx</strong> function.</p>
-<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">make_onnx</span><span class="hljs-params">(x, y)</span>:</span>
-    <span class="hljs-keyword">return</span> sonnx.to_onnx([x], [y])
-
-<span class="hljs-comment"># create device</span>
-dev = device.create_cuda_gpu()
-<span class="hljs-comment">#dev = device.get_default_device()</span>
-<span class="hljs-comment"># create model</span>
-model = CNN()
-<span class="hljs-comment"># load data</span>
-train_x, train_y, valid_x, valid_y = load_dataset()
-<span class="hljs-comment"># normalization</span>
-train_x = train_x / <span class="hljs-number">255</span>
-valid_x = valid_x / <span class="hljs-number">255</span>
-train_y = to_categorical(train_y, <span class="hljs-number">10</span>)
-valid_y = to_categorical(valid_y, <span class="hljs-number">10</span>)
-<span class="hljs-comment"># do training</span>
-autograd.training = <span class="hljs-literal">True</span>
-x, y = train(model, train_x, train_y, dev=dev)
-onnx_model = make_onnx(x, y)
-<span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>
-
-<span class="hljs-comment"># Save the ONNX model</span>
-model_path = os.path.join(<span class="hljs-string">'/'</span>, <span class="hljs-string">'tmp'</span>, <span class="hljs-string">'mnist.onnx'</span>)
-onnx.save(onnx_model, model_path)
-print(<span class="hljs-string">'The model is saved.'</span>)
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="inference"></a><a href="#inference" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference</h3>
-<p>After you export the onnx model, you can find a file called <strong>mnist.onnx</strong> in
-the '/tmp' directory, this model, therefore, can be imported by other libraries.
-Now, if you want to import this onnx model into singa again and do the inference
-using the validation dataset, you can define a class called <strong>Infer</strong>, the
-forward function of Infer will be called by the test function to do inference
-for validation dataset. By the way, you should set the label of training to
-<strong>False</strong> to fix the gradient of autograd operators.</p>
-<p>When import the onnx model, you need to call <strong>onnx.load</strong> to load the onnx
-model firstly. Then the onnx model will be fed into the <strong>soonx.prepare</strong> to
-parse and initiate to a singa model(<strong>sg_ir</strong> in the code). The sg_ir contains a
-singa graph within it, and then you can run an step of inference by feeding
-input to its run function.</p>
-<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span>
-        self.sg_ir = sg_ir
-        <span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items():
-            <span class="hljs-comment"># allow the tensors to be updated</span>
-            tens.requires_grad = <span class="hljs-literal">True</span>
-            tens.stores_grad= <span class="hljs-literal">True</span>
-            sg_ir.tensor_map[idx] = tens
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>] <span class="hljs-comment"># we can run one step of inference by feeding input</span>
-
-<span class="hljs-comment"># load the ONNX model</span>
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev) <span class="hljs-comment"># parse and initiate to a singa model</span>
-
-<span class="hljs-comment"># inference</span>
-autograd.training = <span class="hljs-literal">False</span>
-print(<span class="hljs-string">'The inference result is:'</span>)
-test(Infer(sg_ir), valid_x, valid_y, dev=dev)
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="re-training"></a><a href="#re-training" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Re-training</h3>
-<p>Assume after import the model, you want to re-train the model again, we can
-define a function called <strong>re_train</strong>. Before we call this re_train function, we
-should set the label of training to <strong>True</strong> to make the autograde operators
-update their gradient. And after we finish the training, we set it as <strong>False</strong>
-again to call the test function doing inference.</p>
-<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">re_train</span><span class="hljs-params">(sg_ir,
-             x,
-             y,
-             epochs=<span class="hljs-number">1</span>,
-             batch_size=<span class="hljs-number">64</span>,
-             dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    new_model = Infer(sg_ir)
-
-    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
-        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-            output_batch = new_model.forward(x_batch)
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=<span class="hljs-number">0.01</span>)
-            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
-                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
-                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
-    print(<span class="hljs-string">"re-training completed"</span>)
-    <span class="hljs-keyword">return</span> new_model
-
-<span class="hljs-comment"># load the ONNX model</span>
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev)
-
-<span class="hljs-comment"># re-training</span>
-autograd.training = <span class="hljs-literal">True</span>
-new_model = re_train(sg_ir, train_x, train_y, dev=dev)
-autograd.training = <span class="hljs-literal">False</span>
-test(new_model, valid_x, valid_y, dev=dev)
-</code></pre>
-<h3><a class="anchor" aria-hidden="true" id="transfer-learning"></a><a href="#transfer-learning" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Transfer learning</h3>
-<p>Finally, if we want to do transfer-learning, we can define a function called
-<strong>Trans</strong> to append some layers after the onnx model. For demonstration, the
-code only appends several linear(fully connection) and relu after the onnx
-model. You can define a transfer_learning function to handle the training
-process of the transfer-learning model. And the label of training is the same as
-the previous one.</p>
-<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Trans</span>:</span>
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir, last_layers)</span>:</span>
-        self.sg_ir = sg_ir
-        self.last_layers = last_layers
-        self.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=<span class="hljs-literal">False</span>)
-        self.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=<span class="hljs-literal">False</span>)
-        self.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)
-
-    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
-        y = sg_ir.run([x], last_layers=self.last_layers)[<span class="hljs-number">0</span>]
-        y = self.append_linear1(y)
-        y = autograd.relu(y)
-        y = self.append_linear2(y)
-        y = autograd.relu(y)
-        y = self.append_linear3(y)
-        y = autograd.relu(y)
-        <span class="hljs-keyword">return</span> y
-
-<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">transfer_learning</span><span class="hljs-params">(sg_ir,
-             x,
-             y,
-             epochs=<span class="hljs-number">1</span>,
-             batch_size=<span class="hljs-number">64</span>,
-             dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
-    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
-
-    trans_model = Trans(sg_ir, <span class="hljs-number">-1</span>)
-
-    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
-        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-            output_batch = trans_model.forward(x_batch)
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=<span class="hljs-number">0.07</span>)
-            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
-                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
-                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
-    print(<span class="hljs-string">"transfer-learning completed"</span>)
-    <span class="hljs-keyword">return</span> trans_mode
-
-<span class="hljs-comment"># load the ONNX model</span>
-onnx_model = onnx.load(model_path)
-sg_ir = sonnx.prepare(onnx_model, device=dev)
-
-<span class="hljs-comment"># transfer-learning</span>
-autograd.training = <span class="hljs-literal">True</span>
-new_model = transfer_learning(sg_ir, train_x, train_y, dev=dev)
-autograd.training = <span class="hljs-literal">False</span>
-test(new_model, valid_x, valid_y, dev=dev)
+sgd = opt.SGD(lr=<span class="hljs-number">0</span>.<span class="hljs-number">005</span>, momentum=<span class="hljs-number">0</span>.<span class="hljs-number">9</span>, weight_decay=<span class="hljs-number">1</span>e-<span class="hljs-number">5</span>)
+model.set_optimizer(sgd)
+model.compile([tx], is_train=True, use_graph=graph, sequential=True)
 </code></pre>
 <h2><a class="anchor" aria-hidden="true" id="onnx-model-zoo"></a><a href="#onnx-model-zoo" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>ONNX model zoo</h2>
 <p>The <a href="https://github.com/onnx/models">ONNX Model Zoo</a> is a collection of
@@ -567,6 +237,8 @@
 <tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/shufflenet">ShuffleNet_V2</a></b></td><td><a href="https://arxiv.org/pdf/1707.01083.pdf">Simonyan et al.</a></td><td>Extremely computation efficient CNN model that is designed specifically for mobile devices. This network architecture design considers direct metric such as speed, instead of indirect metric like FLOP. Top-1 error from paper - ~30.6%</td><td>[<img src="https://colab.research.google.com/drive/19HfRu3YHP_H2z3BcZujVFRp23_J5XsuA?usp=sharing" alt="Open In Colab"></td></tr>
 </tbody>
 </table>
+<p>We also provide some re-training examples using VGG and ResNet; see
+<code>examples/onnx/training</code>.</p>
 <h3><a class="anchor" aria-hidden="true" id="object-detection"></a><a href="#object-detection" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Object Detection</h3>
 <p>Object detection models detect the presence of multiple objects in an image and
 segment out areas of the image where the objects are detected.</p>
@@ -698,11 +370,13 @@
 <li><p>Empty tensor Empty tensor is illegal in SINGA.</p></li>
 </ul>
 <h2><a class="anchor" aria-hidden="true" id="implementation"></a><a href="#implementation" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Implementation</h2>
-<p>The code of SINGA ONNX locates at <code>python/singa/soonx.py</code>. There are three main
-class, <code>SingaFrontend</code> and <code>SingaBackend</code> and <code>SingaRep</code>. <code>SingaFrontend</code>
-translates a SINGA model to ONNX model; <code>SingaBackend</code> translates a ONNX model
-to <code>SingaRep</code> object which stores all SINGA operators and tensors(the tensor in
-this doc means SINGA <code>Tensor</code>); <code>SingaRep</code> can be run like a SINGA model.</p>
+<p>The SINGA ONNX code is located at <code>python/singa/sonnx.py</code>. There are four main
+classes: <code>SingaFrontend</code>, <code>SingaBackend</code>, <code>SingaRep</code> and <code>SONNXModel</code>.
+<code>SingaFrontend</code> translates a SINGA model to an ONNX model; <code>SingaBackend</code>
+translates an ONNX model to a <code>SingaRep</code> object, which stores all SINGA operators
+and tensors (in this doc, "tensor" means a SINGA <code>Tensor</code>); <code>SingaRep</code> can be run
+like a SINGA model. <code>SONNXModel</code> inherits from <code>model.Model</code>, which defines a
+unified API for SINGA.</p>
 <h3><a class="anchor" aria-hidden="true" id="singafrontend"></a><a href="#singafrontend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaFrontend</h3>
 <p>The entry function of <code>SingaFrontend</code> is <code>singa_to_onnx_model</code>, which is also
 called <code>to_onnx</code>. <code>singa_to_onnx_model</code> creates the ONNX model, and it also
@@ -724,40 +398,32 @@
 operators to be changed as bool type.</p>
 <h3><a class="anchor" aria-hidden="true" id="singabackend"></a><a href="#singabackend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaBackend</h3>
 <p>The entry function of <code>SingaBackend</code> is <code>prepare</code> which checks the version of
-ONNX model and call <code>_onnx_model_to_singa_net</code> then.</p>
-<p>The purpose of <code>_onnx_model_to_singa_net</code> is to get SINGA tensors and operators.
+the ONNX model and then calls <code>_onnx_model_to_singa_ops</code>.</p>
+<p>The purpose of <code>_onnx_model_to_singa_ops</code> is to get SINGA tensors and operators.
 The tensors are stored in a dictionary by their name in ONNX, and operators are
-stored in queue by the form of
-<code>namedtuple('SingaOps', ['name', 'op', 'handle', 'forward'])</code>. For each
-operator, <code>name</code> is its ONNX node name; <code>op</code> is the ONNX node; <code>forward</code> is the
-SINGA operator's forward function; <code>handle</code> is prepared for some special
-operators such as Conv and Pooling which has <code>handle</code> object.</p>
-<p>The first step of <code>_onnx_model_to_singa_net</code> is to call <code>_init_graph_parameter</code>
-to get all tensors within the model. For trainable weights, it can init SINGA
-<code>Tensor</code> from <code>onnx_model.graph.initializer</code>. Please note, the weights may also
-be stored within graph's input or a ONNX node called <code>Constant</code>, SINGA can also
-handle these.</p>
-<p>Though all weights are stored within ONNX model, the input of the model is
-unknown but its shape and type. So SINGA support two ways to init input, 1,
-generate random tensor by its shape and type, 2, allow the user to assign the
-input. The first way works fine for most models, however, for some model such as
-bert, the indices of matrix cannot be random generated otherwise it will incurs
-errors.</p>
-<p>Then, <code>_onnx_model_to_singa_net</code> iterators all nodes within ONNX graph to
-translate it to SIGNA operators. Also, <code>_rename_operators</code> defines the operators
-name mapping between SINGA and ONNX. <code>_special_operators</code> defines which function
-to be used to translate the operator. <code>_run_node</code> runs the generated SINGA model
-by its input tensors and store its output tensors for being used by later
-operators.</p>
-<p>This class finally return a <code>SingaRep</code> object and stores all SINGA tensors and
-operators within it.</p>
+stored in a queue as <code>namedtuple('SingaOps', ['node', 'operator'])</code>.
+For each operator, <code>node</code> is an instance of <code>OnnxNode</code>, which stores
+some basic information about an ONNX node; <code>operator</code> is the SINGA operator's
+forward function.</p>
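+<p>The record layout above can be sketched in plain Python. This is only an illustration: <code>FakeNode</code>, <code>relu</code> and <code>double</code> are hypothetical stand-ins (not part of SINGA), and plain lists replace SINGA tensors.</p>

```python
from collections import deque, namedtuple

# Same record shape as described above: one ONNX node paired with one operator.
SingaOps = namedtuple('SingaOps', ['node', 'operator'])


class FakeNode:
    """Hypothetical stand-in for OnnxNode; keeps only a name and an op type."""

    def __init__(self, name, op_type):
        self.name = name
        self.op_type = op_type


def relu(x):
    """Stand-in forward function working on a plain list of floats."""
    return [max(0.0, v) for v in x]


def double(x):
    """Second stand-in forward function: multiply every element by two."""
    return [2.0 * v for v in x]


# Operators are appended in graph order and later visited in that same order.
ops = deque()
ops.append(SingaOps(node=FakeNode('relu_0', 'Relu'), operator=relu))
ops.append(SingaOps(node=FakeNode('mul_0', 'Mul'), operator=double))

x = [-1.0, 0.5, 2.0]
for node, operator in ops:
    x = operator(x)
```

+<p>Keeping the node next to its forward function means the runner can look up node metadata, such as input and output names, while it walks the queue.</p>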
+<p><code>_onnx_model_to_singa_ops</code> works in several steps. First, it calls
+<code>_parse_graph_params</code> to get all tensors, which are stored as <code>params</code>. Then it calls
+<code>_parse_graph_inputs_outputs</code> to get all input and output information, stored as
+<code>inputs</code> and <code>outputs</code>. Finally, it iterates over all nodes within the ONNX graph,
+parses each one with <code>_onnx_node_to_singa_op</code> into a SINGA operator or layer, and stores
+them as <code>layers</code>. Some weights are stored within an ONNX node called
+<code>Constant</code>; SONNX handles them with <code>_onnx_constant_to_np</code>, which stores them in
+<code>params</code>.</p>
+<p>Finally, this class returns a <code>SingaRep</code> object that stores the above <code>params</code>,
+<code>inputs</code>, <code>outputs</code> and <code>layers</code>.</p>
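+<p>The overall flow of the backend can be sketched as follows. This is a simplified illustration, not the code in SINGA: the graph is a hypothetical plain dictionary rather than an ONNX protobuf, the helper names only mirror the ones mentioned above, and treating graph inputs that are not parameters as the real model inputs is an assumption of this sketch.</p>

```python
from collections import namedtuple

SingaOps = namedtuple('SingaOps', ['node', 'operator'])

# Hypothetical, simplified graph; a real ONNX graph is a protobuf message.
graph = {
    'initializer': {'W': [2.0, 3.0]},
    'input': ['X', 'W'],
    'output': ['Y'],
    'node': [
        {'op_type': 'Constant', 'output': 'C', 'value': [1.0, 1.0]},
        {'op_type': 'Mul', 'inputs': ['X', 'W'], 'output': 'T'},
        {'op_type': 'Add', 'inputs': ['T', 'C'], 'output': 'Y'},
    ],
}

# Stand-in forward functions keyed by ONNX op type (illustrative only).
FORWARD = {
    'Mul': lambda a, b: [p * q for p, q in zip(a, b)],
    'Add': lambda a, b: [p + q for p, q in zip(a, b)],
}


def parse_graph_params(graph):
    """Collect the stored weights by name."""
    return dict(graph['initializer'])


def parse_graph_inputs_outputs(graph, params):
    """Record input and output names; weights are not user-supplied inputs."""
    inputs = [name for name in graph['input'] if name not in params]
    outputs = list(graph['output'])
    return inputs, outputs


def model_to_ops(graph):
    """Return params, inputs, outputs and the ordered list of operators."""
    params = parse_graph_params(graph)
    inputs, outputs = parse_graph_inputs_outputs(graph, params)
    layers = []
    for node in graph['node']:
        if node['op_type'] == 'Constant':
            # Constant nodes carry weights, so they are folded into params.
            params[node['output']] = node['value']
            continue
        layers.append(SingaOps(node=node, operator=FORWARD[node['op_type']]))
    return params, inputs, outputs, layers


params, inputs, outputs, layers = model_to_ops(graph)
```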
 <h3><a class="anchor" aria-hidden="true" id="singarep"></a><a href="#singarep" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaRep</h3>
 <p><code>SingaRep</code> stores all SINGA tensors and operators. <code>run</code> accepts the input
-of the model and run the SINGA operators one by one following the operators
-queue. The user can use <code>last_layers</code> to decide to run the model till the last
-few layers. Set <code>all_outputs</code> as <code>False</code> to get only the final output, <code>True</code> to
-also get all the intermediate output.</p>
-</span></div></article></div><div class="docLastUpdate"><em>Last updated on 20/09/2020</em></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/graph"><span class="arrow-prev">← </span><span>Model</span></a><a class="docs-next button" href="/docs/dist-train"><span>Distributed Training</span><span class="arrow-next"> →</span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#general-usage">General usage</a><ul class="toc-headings"><li><a href="#loading-an-onnx-model-into-singa">Loading an ONNX Model into SINGA</a></li><li><a href="#inference-singa-model">Inference SINGA model</a></li><li><a href="#saving-singa-model-into-onnx-format">Saving SINGA model into ONNX Format</a></li><li><a href="#re-training-an-onnx-model">Re-training an ONNX model</a></li><li><a href="#transfer-learning-an-onnx-model">Transfer-learning an ONNX model</a></li></ul></li><li><a href="#a-full-example">A Full Example</a><ul class="toc-headings"><li><a href="#load-dataset">Load dataset</a></li><li><a href="#mnist-model">MNIST model</a></li><li><a href="#train-mnist-model-and-export-it-to-onnx">Train mnist model and export it to onnx</a></li><li><a href="#inference">Inference</a></li><li><a href="#re-training">Re-training</a></li><li><a href="#transfer-learning">Transfer learning</a></li></ul></li><li><a href="#onnx-model-zoo">ONNX model zoo</a><ul class="toc-headings"><li><a href="#image-classification">Image Classification</a></li><li><a href="#object-detection">Object Detection</a></li><li><a href="#face-analysis">Face Analysis</a></li><li><a href="#machine-comprehension">Machine Comprehension</a></li></ul></li><li><a href="#supported-operators">Supported operators</a><ul class="toc-headings"><li><a href="#special-comments-for-onnx-backend">Special comments for ONNX backend</a></li></ul></li><li><a href="#implementation">Implementation</a><ul class="toc-headings"><li><a href="#singafrontend">SingaFrontend</a></li><li><a 
href="#singabackend">SingaBackend</a></li><li><a href="#singarep">SingaRep</a></li></ul></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a href="/en/https://apache-singa.readthedocs.io/en/latest/">API Reference</a><a href="/docs/examples">Examples</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/blog">SINGA News</a><a href="https://github.com/apache/singa">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img 
src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2020
+of the model and runs the SINGA operators one by one in the order of the operator
+queue. The user can use <code>last_layers</code> to truncate the model so that its last few
+layers are not run.</p>
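+<p>A minimal sketch of running an operator queue with an optional cut-off is shown below. The function <code>run_ops</code> is hypothetical, the operators are plain Python callables, and interpreting <code>last_layers</code> as a Python slice end is an assumption of this sketch; the exact semantics are defined in <code>sonnx.py</code>.</p>

```python
def run_ops(layers, x, last_layers=None):
    """Apply operators in queue order, optionally stopping early.

    last_layers is used as a slice end here, so -1 skips the final
    operator and None runs everything. This is an assumed convention.
    """
    for operator in layers[:last_layers]:
        x = operator(x)
    return x


# Three stand-in operators on a scalar input.
layers = [lambda v: v + 1, lambda v: v * 10, lambda v: v - 3]

full = run_ops(layers, 1)                       # runs all three operators
truncated = run_ops(layers, 1, last_layers=-1)  # skips the last operator
```

+<p>Truncating the queue this way is what makes transfer learning convenient: the output of the shortened model can be fed into newly added layers.</p>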
+<h3><a class="anchor" aria-hidden="true" id="sonnxmodel"></a><a href="#sonnxmodel" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SONNXModel</h3>
+<p><code>SONNXModel</code> inherits from <code>model.Model</code> and implements the method
+<code>forward</code> to provide a unified API with other SINGA models.</p>
+</span></div></article></div><div class="docLastUpdate"><em>Last updated on 25/11/2020</em></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/graph"><span class="arrow-prev">← </span><span>Model</span></a><a class="docs-next button" href="/docs/dist-train"><span>Distributed Training</span><span class="arrow-next"> →</span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#general-usage">General usage</a><ul class="toc-headings"><li><a href="#loading-an-onnx-model-into-singa">Loading an ONNX Model into SINGA</a></li><li><a href="#inference-singa-model">Inference SINGA model</a></li><li><a href="#saving-singa-model-into-onnx-format">Saving SINGA model into ONNX Format</a></li><li><a href="#re-training-an-onnx-model">Re-training an ONNX model</a></li><li><a href="#transfer-learning-an-onnx-model">Transfer-learning an ONNX model</a></li></ul></li><li><a href="#onnx-model-zoo">ONNX model zoo</a><ul class="toc-headings"><li><a href="#image-classification">Image Classification</a></li><li><a href="#object-detection">Object Detection</a></li><li><a href="#face-analysis">Face Analysis</a></li><li><a href="#machine-comprehension">Machine Comprehension</a></li></ul></li><li><a href="#supported-operators">Supported operators</a><ul class="toc-headings"><li><a href="#special-comments-for-onnx-backend">Special comments for ONNX backend</a></li></ul></li><li><a href="#implementation">Implementation</a><ul class="toc-headings"><li><a href="#singafrontend">SingaFrontend</a></li><li><a href="#singabackend">SingaBackend</a></li><li><a href="#singarep">SingaRep</a></li><li><a href="#sonnxmodel">SONNXModel</a></li></ul></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a 
href="/en/https://apache-singa.readthedocs.io/en/latest/">API Reference</a><a href="/docs/examples">Examples</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/blog">SINGA News</a><a href="https://github.com/apache/singa">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2020
    The Apache Software Foundation. All rights reserved.
    Apache SINGA, Apache, the Apache feather logo, and
    the Apache SINGA project logos are trademarks of The