<!DOCTYPE html><html lang="en"><head><meta charSet="utf-8"/><meta http-equiv="X-UA-Compatible" content="IE=edge"/><title>ONNX · Apache SINGA</title><meta name="viewport" content="width=device-width"/><meta name="generator" content="Docusaurus"/><meta name="description" content="&lt;!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --&gt;"/><meta name="docsearch:version" content="3.0.0.rc1"/><meta name="docsearch:language" content="en"/><meta property="og:title" content="ONNX · Apache SINGA"/><meta property="og:type" content="website"/><meta property="og:url" content="https://feynmandna.github.io/"/><meta property="og:description" content="&lt;!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --&gt;"/><meta property="og:image" content="https://feynmandna.github.io/img/singa_twitter_banner.jpeg"/><meta name="twitter:card" content="summary"/><meta name="twitter:image" content="https://feynmandna.github.io/img/singa_twitter_banner.jpeg"/><link rel="shortcut icon" href="/img/favicon.ico"/><link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/atom-one-dark.min.css"/><link rel="alternate" type="application/atom+xml" href="https://feynmandna.github.io/blog/atom.xml" title="Apache SINGA Blog ATOM Feed"/><link rel="alternate" type="application/rss+xml" href="https://feynmandna.github.io/blog/feed.xml" title="Apache SINGA Blog RSS Feed"/><script type="text/javascript" src="https://buttons.github.io/buttons.js"></script><script src="https://unpkg.com/vanilla-back-to-top@7.1.14/dist/vanilla-back-to-top.min.js"></script><script>
document.addEventListener('DOMContentLoaded', function() {
addBackToTop(
{"zIndex":100}
)
});
</script><script src="/js/scrollSpy.js"></script><link rel="stylesheet" href="/css/main.css"/><script src="/js/codetabs.js"></script></head><body class="sideNavVisible separateOnPageNav"><div class="fixedHeaderContainer"><div class="headerWrapper wrapper"><header><a href="/"><img class="logo" src="/img/singa.png" alt="Apache SINGA"/></a><a href="/versions"><h3>3.0.0.rc1</h3></a><div class="navigationWrapper navigationSlider"><nav class="slidingNav"><ul class="nav-site nav-site-internal"><li class=""><a href="/docs/installation" target="_self">Docs</a></li><li class=""><a href="/docs/source-repository" target="_self">Community</a></li><li class=""><a href="/blog/" target="_self">News</a></li><li class=""><a href="https://apache-singa.readthedocs.io/en/latest/" target="_self">API</a></li><li class=""><a target="_self"></a></li><li class=""><a href="https://github.com/apache/singa-doc" target="_self">GitHub</a></li></ul></nav></div></header></div></div><div class="navPusher"><div class="docMainWrapper wrapper"><div class="container mainContainer docsContainer"><div class="wrapper"><div class="post"><header class="postHeader"><h1 id="__docusaurus" class="postHeaderTitle">ONNX</h1></header><article><div><span><!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -->
<p>ONNX is an open format for representing machine learning models, which makes
it possible to transfer trained models between different deep learning
frameworks. The main functionality of ONNX has been integrated into SINGA, and
several basic operators are supported. Support for more operators is under
development.</p>
<p>SINGA supports the following <a href="https://github.com/onnx/onnx/blob/master/docs/Versioning.md">ONNX
version</a>:</p>
<table>
<thead>
<tr><th>ONNX version</th><th>File format version</th><th>Opset version ai.onnx</th><th>Opset version ai.onnx.ml</th><th>Opset version ai.onnx.training</th></tr>
</thead>
<tbody>
<tr><td>1.6.0</td><td>6</td><td>11</td><td>2</td><td>-</td></tr>
</tbody>
</table>
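<p>If you are unsure whether a model falls within these opsets, a quick check before calling <code>sonnx.prepare</code> can help. The sketch below is an illustration, not part of SINGA: <code>SUPPORTED_OPSETS</code> and <code>unsupported_opsets</code> are hypothetical names, and the check conservatively assumes that any opset newer than the table above, or any domain not listed there, may not be supported. It takes plain <code>(domain, version)</code> pairs so that it runs without the <code>onnx</code> package.</p>
<pre><code class="hljs css language-python"># Opset versions from the table above (hypothetical constant, not a SINGA API).
# ai.onnx.training is listed as "-", so it is deliberately left out.
SUPPORTED_OPSETS = {"ai.onnx": 11, "ai.onnx.ml": 2}


def unsupported_opsets(opset_imports):
    """Return the (domain, version) pairs that exceed the supported opsets.

    opset_imports is an iterable of (domain, version) pairs; an empty domain
    string denotes the default ai.onnx domain, as in ONNX model files.
    """
    problems = []
    for domain, version in opset_imports:
        domain = domain or "ai.onnx"
        limit = SUPPORTED_OPSETS.get(domain)
        if limit is None or version > limit:
            problems.append((domain, version))
    return problems


print(unsupported_opsets([("", 11)]))                     # []
print(unsupported_opsets([("", 13), ("ai.onnx.ml", 2)]))  # [('ai.onnx', 13)]
</code></pre>
<p>With the <code>onnx</code> package installed, the pairs can be taken from a loaded model with <code>[(o.domain, o.version) for o in onnx_model.opset_import]</code>.</p>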
<h2><a class="anchor" aria-hidden="true" id="general-usage"></a><a href="#general-usage" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>General usage</h2>
<p>SINGA's ONNX support covers the basic functionality. Refer to the following
tutorials for general usage:</p>
<h3><a class="anchor" aria-hidden="true" id="loading-an-onnx-model-into-singa"></a><a href="#loading-an-onnx-model-into-singa" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Loading an ONNX Model into SINGA</h3>
<p>This part introduces how to import an ONNX model and prepare it as a SINGA model.
After you load an ONNX model with <code>onnx.load</code>, you need to update the
model's batch size, since most models use a placeholder for the batch dimension.
The <code>update_batch_size</code> function below shows one way to do this. You only
need to update the batch size of the input and output; the shapes of the inner
tensors are inferred automatically.</p>
<p>Then, you can prepare the SINGA model by using <code>sonnx.prepare</code>. This function
iterates over the nodes of the ONNX model's graph and translates each of them into a
SINGA operator, loads all stored weights, and infers the shape of each intermediate tensor.
For the device used, please refer to the <code>device</code> section.</p>
<pre><code class="hljs css language-python3"><span class="hljs-built_in">import</span> onnx
from singa <span class="hljs-built_in">import</span> device
from singa <span class="hljs-built_in">import</span> sonnx
def update_batch_size(onnx_model, batch_size):
<span class="hljs-attr">model_input</span> = onnx_model.graph.input[<span class="hljs-number">0</span>]
model_input.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size
<span class="hljs-attr">model_output</span> = onnx_model.graph.output[<span class="hljs-number">0</span>]
model_output.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size
return onnx_model
<span class="hljs-attr">model_path</span> = <span class="hljs-string">"PATH/To/ONNX/MODEL"</span>
<span class="hljs-attr">onnx_model</span> = onnx.load(model_path)
<span class="hljs-comment"># set batch size</span>
<span class="hljs-attr">onnx_model</span> = update_batch_size(onnx_model, <span class="hljs-number">1</span>)
<span class="hljs-comment"># prepare the model</span>
<span class="hljs-attr">dev</span> = device.create_cuda_gpu()
<span class="hljs-attr">sg_ir</span> = sonnx.prepare(onnx_model, <span class="hljs-attr">device=dev)</span>
</code></pre>
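<p>The <code>update_batch_size</code> helper above only changes the first graph input and output. For models with several inputs or outputs, a variant along the following lines may be useful. It is a sketch under two assumptions: the batch is the leading dimension of every data input and output, and graph inputs that are also initializers (older ONNX files list weights as graph inputs) must be skipped so that weight shapes are not changed. The name <code>update_all_batch_sizes</code> is hypothetical.</p>
<pre><code class="hljs css language-python">def update_all_batch_sizes(onnx_model, batch_size):
    """Set the leading dimension of every data input and every output.

    Inputs whose names match an initializer hold weights, not data,
    so their shapes are left untouched.
    """
    graph = onnx_model.graph
    weight_names = {init.name for init in graph.initializer}
    value_infos = [vi for vi in graph.input if vi.name not in weight_names]
    value_infos += list(graph.output)
    for value_info in value_infos:
        dims = value_info.type.tensor_type.shape.dim
        if dims:  # skip values without shape information
            # In ONNX protobufs, setting dim_value replaces a symbolic dim_param.
            dims[0].dim_value = batch_size
    return onnx_model
</code></pre>
<p>It is used in the same way as the original helper, for example <code>onnx_model = update_all_batch_sizes(onnx.load(model_path), 1)</code>.</p>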
<h3><a class="anchor" aria-hidden="true" id="inferernce-singa-model"></a><a href="#inferernce-singa-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.270 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference with a SINGA model</h3>
<p>After you load and prepare a SINGA model, you can run inference by calling
<code>sg_ir.run</code>, as in the following code. The inputs and outputs must be SINGA
<code>Tensor</code> objects. Since the SINGA model returns its outputs as a list, a model
with a single output only needs the first element, as the <code>forward</code> method of
the <code>Infer</code> class does.</p>
<pre><code class="hljs css language-python3"><span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> tensor
<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, sg_ir)</span></span>:
        <span class="hljs-keyword">self</span>.sg_ir = sg_ir
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x)</span></span>:
        <span class="hljs-comment"># run() returns a list of outputs; this model has only one</span>
        <span class="hljs-keyword">return</span> <span class="hljs-keyword">self</span>.sg_ir.run([x])[<span class="hljs-number">0</span>]
data = get_dataset()  <span class="hljs-comment"># placeholder: load your input data as a NumPy array</span>
x = tensor.Tensor(device=dev, data=data)
model = Infer(sg_ir)
y = model.forward(x)
</code></pre>
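<p>Because the batch size was fixed when the model was prepared, every tensor passed to <code>run</code> should normally have exactly that many rows. The sketch below feeds a whole NumPy array through in fixed-size batches, zero-pads the last partial batch, and drops the padded rows from the result. The helper <code>run_in_batches</code> is hypothetical, not a SINGA API; it takes the model call as a function so that it can be tried without SINGA.</p>
<pre><code class="hljs css language-python">import numpy as np


def run_in_batches(run_fn, data, batch_size):
    """Apply run_fn to data in batches of exactly batch_size rows.

    run_fn maps a (batch_size, ...) array to a (batch_size, ...) array.
    The last partial batch is padded with zeros, and the padded rows
    are removed from the output.
    """
    outputs = []
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        n_valid = len(batch)
        pad = batch_size - n_valid
        if pad:
            filler = np.zeros((pad,) + batch.shape[1:], dtype=batch.dtype)
            batch = np.concatenate([batch, filler])
        outputs.append(run_fn(batch)[:n_valid])
    return np.concatenate(outputs)
</code></pre>
<p>With SINGA, <code>run_fn</code> could wrap the <code>Infer</code> model above, for example <code>lambda b: tensor.to_numpy(model.forward(tensor.Tensor(device=dev, data=b)))</code>.</p>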
<h3><a class="anchor" aria-hidden="true" id="saving-an-onnx-model-from-singa"></a><a href="#saving-an-onnx-model-from-singa" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Saving an ONNX Model from SINGA</h3>
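<p>If you have a SINGA model, you can export it as an ONNX model by passing its input and output tensors to <code>sonnx.to_onnx</code>, and then save it with <code>onnx.save</code>:</p>
<pre><code class="hljs css language-python3"><span class="hljs-keyword">import</span> onnx
<span class="hljs-comment"># x: an input Tensor, y: the output Tensor that the model computed from x</span>
onnx_model = sonnx.to_onnx([x], [y])
onnx.save(onnx_model, <span class="hljs-string">"PATH/To/ONNX/MODEL"</span>)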
</code></pre>
<h3><a class="anchor" aria-hidden="true" id="re-training-a-onnx-model"></a><a href="#re-training-a-onnx-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Re-training an ONNX model</h3>
<p>You can also re-train an ONNX model after loading it into SINGA, as in the following
code. Note that you must enable every tensor of the SINGA model to store its gradient
by setting <code>tens.requires_grad = True</code> and <code>tens.stores_grad = True</code>.</p>
<pre><code class="hljs css language-python3"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, sg_ir)</span></span>:
<span class="hljs-keyword">self</span>.sg_ir = sg_ir
<span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items():
<span class="hljs-comment"># allow the tensors to be updated</span>
tens.requires_grad = True
tens.stores_grad = True
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x)</span></span>:
        <span class="hljs-keyword">return</span> <span class="hljs-keyword">self</span>.sg_ir.run([x])[<span class="hljs-number">0</span>]
autograd.training = False
model = Infer(sg_ir)
autograd.training = True
<span class="hljs-comment"># then train the model as usual</span>
</code></pre>
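<p>The loop above makes every tensor in <code>sg_ir.tensor_map</code> trainable. If you would rather fine-tune only part of the network, one option is to set the two flags selectively. The sketch below assumes that a tensor with both flags set to <code>False</code> is not updated during re-training, and that the names you pass match keys of <code>tensor_map</code>; the helper <code>set_trainable</code> is hypothetical.</p>
<pre><code class="hljs css language-python">def set_trainable(tensor_map, trainable_names):
    """Enable gradients only for tensors whose names are in trainable_names.

    Every other tensor gets requires_grad and stores_grad set to False so
    that, under the assumption stated above, it stays fixed.
    Returns the sorted names that were made trainable.
    """
    trainable_names = set(trainable_names)
    enabled = []
    for name, tens in tensor_map.items():
        flag = name in trainable_names
        tens.requires_grad = flag
        tens.stores_grad = flag
        if flag:
            enabled.append(name)
    return sorted(enabled)
</code></pre>
<p>For example, <code>set_trainable(sg_ir.tensor_map, ["W2"])</code> would leave only a tensor named <code>W2</code> (a hypothetical name) trainable. The actual names depend on how the ONNX model was exported, so inspect <code>list(sg_ir.tensor_map)</code> first.</p>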
<h3><a class="anchor" aria-hidden="true" id="transfer-learning-a-onnx-model"></a><a href="#transfer-learning-a-onnx-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Transfer learning with an ONNX model</h3>
<p>You can also append layers to the end of an ONNX model for transfer learning, as
shown below. The <code>last_layers</code> argument controls how much of the ONNX model is
kept: only the ONNX layers before index <code>last_layers</code> are run, so <code>-1</code>
drops the last layer. You can then append more layers as in a normal SINGA model.</p>
<pre><code class="hljs css language-python3">from singa import autograd

class Trans:
    def __init__(self, sg_ir, last_layers):
        self.sg_ir = sg_ir
        self.last_layers = last_layers
        # new layers appended after the truncated ONNX model
        self.append_linear1 = autograd.Linear(500, 128, bias=False)
        self.append_linear2 = autograd.Linear(128, 32, bias=False)
        self.append_linear3 = autograd.Linear(32, 10, bias=False)

    def forward(self, x):
        # run the ONNX layers up to last_layers, then the appended layers
        y = self.sg_ir.run([x], last_layers=self.last_layers)[0]
        y = self.append_linear1(y)
        y = autograd.relu(y)
        y = self.append_linear2(y)
        y = autograd.relu(y)
        y = self.append_linear3(y)
        y = autograd.relu(y)
        return y

autograd.training = False
model = Trans(sg_ir, -1)
autograd.training = True
# then train the model as usual
</code></pre>
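<p>The effect of <code>last_layers</code> can be pictured with ordinary Python slicing. This sketch assumes that <code>run</code> executes the prepared nodes in order and keeps those in <code>nodes[:last_layers]</code>. That reading fits the example above: for the MNIST CNN later on this page, dropping the final node leaves the 500-dimensional output that <code>Linear(500, 128)</code> expects. The node list is a hypothetical illustration, not the exact graph that SINGA exports.</p>
<pre><code class="hljs css language-python"># Hypothetical node sequence resembling the MNIST CNN used later on this page.
CNN_NODES = ["Conv", "Relu", "MaxPool", "Conv", "Relu", "MaxPool",
             "Flatten", "MatMul", "Relu", "MatMul"]


def kept_nodes(nodes, last_layers):
    """Nodes that would run for a given last_layers value (slice semantics)."""
    return nodes[:last_layers]


print(kept_nodes(CNN_NODES, -1)[-1])   # Relu: output of the 500-unit layer
print(len(kept_nodes(CNN_NODES, -3)))  # 7: stops after Flatten
</code></pre>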
<h2><a class="anchor" aria-hidden="true" id="example-onnx-mnist-on-singa"></a><a href="#example-onnx-mnist-on-singa" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Example: ONNX MNIST on SINGA</h2>
<p>This part introduces the usage of SINGA ONNX through the MNIST example. It shows how
to export, load, run inference on, re-train, and apply transfer learning to the MNIST
model. You can try this part
<a href="https://colab.research.google.com/drive/1-YOfQqqw3HNhS8WpB8xjDQYutRdUdmCq">here</a>.</p>
<h3><a class="anchor" aria-hidden="true" id="load-dataset"></a><a href="#load-dataset" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Load dataset</h3>
<p>First, import the necessary libraries and define some auxiliary
functions for downloading and preprocessing the dataset:</p>
<pre><code class="hljs css language-python"><span class="hljs-keyword">import</span> os
<span class="hljs-keyword">import</span> urllib.request
<span class="hljs-keyword">import</span> gzip
<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-keyword">import</span> codecs
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> device
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> tensor
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> autograd
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> sonnx
<span class="hljs-keyword">import</span> onnx
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">load_dataset</span><span class="hljs-params">()</span>:</span>
train_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'</span>
train_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'</span>
valid_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'</span>
valid_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'</span>
train_x = read_image_file(check_exist_or_download(train_x_url)).astype(
np.float32)
train_y = read_label_file(check_exist_or_download(train_y_url)).astype(
np.float32)
valid_x = read_image_file(check_exist_or_download(valid_x_url)).astype(
np.float32)
valid_y = read_label_file(check_exist_or_download(valid_y_url)).astype(
np.float32)
<span class="hljs-keyword">return</span> train_x, train_y, valid_x, valid_y
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">check_exist_or_download</span><span class="hljs-params">(url)</span>:</span>
download_dir = <span class="hljs-string">'/tmp/'</span>
name = url.rsplit(<span class="hljs-string">'/'</span>, <span class="hljs-number">1</span>)[<span class="hljs-number">-1</span>]
filename = os.path.join(download_dir, name)
<span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> os.path.isfile(filename):
print(<span class="hljs-string">"Downloading %s"</span> % url)
urllib.request.urlretrieve(url, filename)
<span class="hljs-keyword">return</span> filename
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_label_file</span><span class="hljs-params">(path)</span>:</span>
<span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
data = f.read()
<span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2049</span>
length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">8</span>).reshape(
(length))
<span class="hljs-keyword">return</span> parsed
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">get_int</span><span class="hljs-params">(b)</span>:</span>
<span class="hljs-keyword">return</span> int(codecs.encode(b, <span class="hljs-string">'hex'</span>), <span class="hljs-number">16</span>)
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_image_file</span><span class="hljs-params">(path)</span>:</span>
<span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
data = f.read()
<span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2051</span>
length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
num_rows = get_int(data[<span class="hljs-number">8</span>:<span class="hljs-number">12</span>])
num_cols = get_int(data[<span class="hljs-number">12</span>:<span class="hljs-number">16</span>])
parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">16</span>).reshape(
(length, <span class="hljs-number">1</span>, num_rows, num_cols))
<span class="hljs-keyword">return</span> parsed
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">to_categorical</span><span class="hljs-params">(y, num_classes)</span>:</span>
y = np.array(y, dtype=<span class="hljs-string">"int"</span>)
n = y.shape[<span class="hljs-number">0</span>]
categorical = np.zeros((n, num_classes))
categorical[np.arange(n), y] = <span class="hljs-number">1</span>
categorical = categorical.astype(np.float32)
<span class="hljs-keyword">return</span> categorical
</code></pre>
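<p>The magic numbers checked above encode the IDX file layout: two zero bytes, a type code (<code>0x08</code> for unsigned bytes) and the number of dimensions. So 2049 is <code>0x00000801</code> (1-D labels) and 2051 is <code>0x00000803</code> (3-D images). The hypothetical helper <code>read_idx</code> below is a sketch of a single reader that derives the shape from the header instead of hard-coding it; it expects already-decompressed bytes.</p>
<pre><code class="hljs css language-python">import numpy as np


def read_idx(data):
    """Parse decompressed IDX bytes into a NumPy uint8 array."""
    if data[0:2] != b"\x00\x00" or data[2] != 0x08:
        raise ValueError("expected an unsigned-byte IDX file")
    ndim = data[3]
    # each dimension size is a 4-byte big-endian integer
    dims = [int.from_bytes(data[4 + 4 * i:8 + 4 * i], "big") for i in range(ndim)]
    return np.frombuffer(data, dtype=np.uint8, offset=4 + 4 * ndim).reshape(dims)


# a tiny synthetic label file: magic 2049, 3 items, labels 7, 0, 9
labels = bytes([0, 0, 8, 1]) + (3).to_bytes(4, "big") + bytes([7, 0, 9])
print(read_idx(labels))  # [7 0 9]
</code></pre>
<p>For the gzipped MNIST files, <code>read_idx(gzip.open(path, 'rb').read())</code> should give the same values as <code>read_label_file</code>; for images, add the channel axis with <code>.reshape(-1, 1, 28, 28)</code>, as <code>read_image_file</code> does.</p>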
<h3><a class="anchor" aria-hidden="true" id="mnist-model"></a><a href="#mnist-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>MNIST model</h3>
<p>Then you can define a class called <strong>CNN</strong> to construct the MNIST model, which
consists of convolution, pooling, fully connected and ReLU layers. You
can also define a function to calculate the <strong>accuracy</strong> of the predictions. Finally,
you can define a <strong>train</strong> and a <strong>test</strong> function to handle the training and
prediction processes.</p>
<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">CNN</span>:</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self)</span>:</span>
self.conv1 = autograd.Conv2d(<span class="hljs-number">1</span>, <span class="hljs-number">20</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
self.conv2 = autograd.Conv2d(<span class="hljs-number">20</span>, <span class="hljs-number">50</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
self.linear1 = autograd.Linear(<span class="hljs-number">4</span> * <span class="hljs-number">4</span> * <span class="hljs-number">50</span>, <span class="hljs-number">500</span>, bias=<span class="hljs-literal">False</span>)
self.linear2 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)
self.pooling1 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
self.pooling2 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
y = self.conv1(x)
y = autograd.relu(y)
y = self.pooling1(y)
y = self.conv2(y)
y = autograd.relu(y)
y = self.pooling2(y)
y = autograd.flatten(y)
y = self.linear1(y)
y = autograd.relu(y)
y = self.linear2(y)
<span class="hljs-keyword">return</span> y
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">accuracy</span><span class="hljs-params">(pred, target)</span>:</span>
y = np.argmax(pred, axis=<span class="hljs-number">1</span>)
t = np.argmax(target, axis=<span class="hljs-number">1</span>)
a = y == t
<span class="hljs-keyword">return</span> np.array(a, <span class="hljs-string">"int"</span>).sum() / float(len(t))
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train</span><span class="hljs-params">(model,
x,
y,
epochs=<span class="hljs-number">1</span>,
batch_size=<span class="hljs-number">64</span>,
dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
<span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
l_idx = b * batch_size
r_idx = (b + <span class="hljs-number">1</span>) * batch_size
x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
output_batch = model.forward(x_batch)
<span class="hljs-comment"># onnx_model = sonnx.to_onnx([x_batch], [y])</span>
<span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>
loss = autograd.softmax_cross_entropy(output_batch, target_batch)
accuracy_rate = accuracy(tensor.to_numpy(output_batch),
tensor.to_numpy(target_batch))
sgd = opt.SGD(lr=<span class="hljs-number">0.001</span>)
<span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
sgd.update(p, gp)
sgd.step()
<span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
(accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
print(<span class="hljs-string">"training completed"</span>)
<span class="hljs-keyword">return</span> x_batch, output_batch
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">test</span><span class="hljs-params">(model, x, y, batch_size=<span class="hljs-number">64</span>, dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
result = <span class="hljs-number">0</span>
<span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
l_idx = b * batch_size
r_idx = (b + <span class="hljs-number">1</span>) * batch_size
x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
output_batch = model.forward(x_batch)
result += accuracy(tensor.to_numpy(output_batch),
tensor.to_numpy(target_batch))
print(<span class="hljs-string">"testing acc %6.2f"</span> % (result / batch_number))
</code></pre>
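<p>The input size of <code>linear1</code>, <code>4 * 4 * 50</code>, follows from the 28x28 MNIST images. The short calculation below applies the usual output-size formula, <code>(size + 2 * padding - kernel) // stride + 1</code>, to each layer of <code>CNN</code>, assuming the default stride of 1 for the convolutions. It only checks the arithmetic and is not SINGA code.</p>
<pre><code class="hljs css language-python">def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28
size = out_size(size, 5)            # conv1, 5x5, padding 0 -> 24
size = out_size(size, 2, stride=2)  # pooling1, 2x2, stride 2 -> 12
size = out_size(size, 5)            # conv2, 5x5 -> 8
size = out_size(size, 2, stride=2)  # pooling2 -> 4
features = size * size * 50         # 50 output channels from conv2
print(features)  # 800
</code></pre>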
<h3><a class="anchor" aria-hidden="true" id="train-mnist-model-and-export-it-to-onnx"></a><a href="#train-mnist-model-and-export-it-to-onnx" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Train the MNIST model and export it to ONNX</h3>
<p>Now, you can train the MNIST model and export it to ONNX by calling the
<strong>sonnx.to_onnx</strong> function.</p>
<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">make_onnx</span><span class="hljs-params">(x, y)</span>:</span>
<span class="hljs-keyword">return</span> sonnx.to_onnx([x], [y])
<span class="hljs-comment"># create device</span>
dev = device.create_cuda_gpu()
<span class="hljs-comment">#dev = device.get_default_device()</span>
<span class="hljs-comment"># create model</span>
model = CNN()
<span class="hljs-comment"># load data</span>
train_x, train_y, valid_x, valid_y = load_dataset()
<span class="hljs-comment"># normalization</span>
train_x = train_x / <span class="hljs-number">255</span>
valid_x = valid_x / <span class="hljs-number">255</span>
train_y = to_categorical(train_y, <span class="hljs-number">10</span>)
valid_y = to_categorical(valid_y, <span class="hljs-number">10</span>)
<span class="hljs-comment"># do training</span>
autograd.training = <span class="hljs-literal">True</span>
x, y = train(model, train_x, train_y, dev=dev)
onnx_model = make_onnx(x, y)
<span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>
<span class="hljs-comment"># Save the ONNX model</span>
model_path = os.path.join(<span class="hljs-string">'/'</span>, <span class="hljs-string">'tmp'</span>, <span class="hljs-string">'mnist.onnx'</span>)
onnx.save(onnx_model, model_path)
print(<span class="hljs-string">'The model is saved.'</span>)
</code></pre>
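<p>The normalization step above scales pixel values into [0, 1] and one-hot encodes the labels. A rough NumPy sketch of the two steps follows; the <code>to_categorical</code> here is a hypothetical re-implementation for illustration and may differ from the helper defined earlier in this example.</p>
<pre><code class="hljs css language-python">import numpy as np

def to_categorical(y, num_classes):
    # Hypothetical re-implementation: one row per label, with 1.0 in the
    # column given by the label and 0.0 everywhere else
    y = np.asarray(y, dtype=int).reshape(-1)
    out = np.zeros((y.shape[0], num_classes), dtype=np.float32)
    out[np.arange(y.shape[0]), y] = 1.0
    return out

# Toy stand-ins for two 1x28x28 MNIST images and their labels
images = np.array([np.zeros((1, 28, 28)), np.full((1, 28, 28), 255)], dtype=np.float32)
labels = np.array([3, 7])

images = images / 255                  # pixel values now lie in [0, 1]
onehot = to_categorical(labels, 10)    # shape (2, 10)
print(onehot.shape, onehot[0].argmax(), images.max())
</code></pre>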
<h3><a class="anchor" aria-hidden="true" id="inference"></a><a href="#inference" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference</h3>
<p>After exporting the model, you will find a file called <strong>mnist.onnx</strong> in
the <code>/tmp</code> directory; this file can be imported by other libraries that support ONNX.
To import this ONNX model into SINGA again and run inference on the validation
dataset, you can define a class called <strong>Infer</strong>, whose forward
function is called by the test function to do inference on the validation
dataset. Before running inference, set <strong>autograd.training</strong> to
<strong>False</strong> so that the autograd operators do not compute gradients.</p>
<p>To import the ONNX model, first call <strong>onnx.load</strong> to load it. Then feed
the loaded model into <strong>sonnx.prepare</strong>, which parses it and initializes a
SINGA model (<strong>sg_ir</strong> in the code). sg_ir contains a SINGA graph, and
you can run one step of inference by feeding an input to its run function.</p>
<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span>
        self.sg_ir = sg_ir
        <span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items():
            <span class="hljs-comment"># allow the tensors to be updated</span>
            tens.requires_grad = <span class="hljs-literal">True</span>
            tens.stores_grad = <span class="hljs-literal">True</span>
            sg_ir.tensor_map[idx] = tens

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
        <span class="hljs-comment"># run one step of inference by feeding the input</span>
        <span class="hljs-keyword">return</span> self.sg_ir.run([x])[<span class="hljs-number">0</span>]

<span class="hljs-comment"># load the ONNX model</span>
onnx_model = onnx.load(model_path)
sg_ir = sonnx.prepare(onnx_model, device=dev) <span class="hljs-comment"># parse and initiate to a singa model</span>
<span class="hljs-comment"># inference</span>
autograd.training = <span class="hljs-literal">False</span>
print(<span class="hljs-string">'The inference result is:'</span>)
test(Infer(sg_ir), valid_x, valid_y, dev=dev)
</code></pre>
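<p>The <code>test</code> function used above is not shown in this section. A minimal sketch of how such a function might call a model's <code>forward</code> method batch by batch is below; <code>evaluate</code>, <code>accuracy</code>, and <code>IdentityModel</code> are hypothetical names, and NumPy arrays stand in for SINGA tensors. Accuracy compares the arg-max of the predictions with the arg-max of the one-hot targets.</p>
<pre><code class="hljs css language-python">import numpy as np

def accuracy(pred, target):
    # Fraction of rows whose predicted class (arg-max) equals the true class
    return float(np.mean(np.argmax(pred, axis=1) == np.argmax(target, axis=1)))

def evaluate(model, x, y, batch_size=64):
    # Hypothetical test loop: only whole batches are evaluated,
    # mirroring the x.shape[0] // batch_size pattern used in this example
    batch_number = x.shape[0] // batch_size
    correct = 0.0
    for b in range(batch_number):
        xb = x[b * batch_size:(b + 1) * batch_size]
        yb = y[b * batch_size:(b + 1) * batch_size]
        correct += accuracy(model.forward(xb), yb) * batch_size
    return correct / (batch_number * batch_size)

class IdentityModel:
    # Stand-in for Infer: forward simply returns its input as class scores
    def forward(self, x):
        return x

y = np.eye(10, dtype=np.float32)[np.arange(128) % 10]   # 128 one-hot rows
print(evaluate(IdentityModel(), y, y))
</code></pre>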
<h3><a class="anchor" aria-hidden="true" id="re-training"></a><a href="#re-training" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Re-training</h3>
<p>Suppose that, after importing the model, you want to re-train it. You can
define a function called <strong>re_train</strong>. Before calling re_train, set
<strong>autograd.training</strong> to <strong>True</strong> so that the autograd operators compute
gradients. After training finishes, set it back to <strong>False</strong> before calling the
test function for inference.</p>
<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">re_train</span><span class="hljs-params">(sg_ir,
             x,
             y,
             epochs=<span class="hljs-number">1</span>,
             batch_size=<span class="hljs-number">64</span>,
             dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
    new_model = Infer(sg_ir)
    <span class="hljs-comment"># create the optimizer once, outside the training loop</span>
    sgd = opt.SGD(lr=<span class="hljs-number">0.01</span>)
    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
            l_idx = b * batch_size
            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
            output_batch = new_model.forward(x_batch)
            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
                                     tensor.to_numpy(target_batch))
            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
                sgd.update(p, gp)
            sgd.step()
            <span class="hljs-keyword">if</span> b % <span class="hljs-number">100</span> == <span class="hljs-number">0</span>:
                print(<span class="hljs-string">"acc %6.2f, loss %6.2f"</span> %
                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
    print(<span class="hljs-string">"re-training completed"</span>)
    <span class="hljs-keyword">return</span> new_model

<span class="hljs-comment"># load the ONNX model</span>
onnx_model = onnx.load(model_path)
sg_ir = sonnx.prepare(onnx_model, device=dev)
<span class="hljs-comment"># re-training</span>
autograd.training = <span class="hljs-literal">True</span>
new_model = re_train(sg_ir, train_x, train_y, dev=dev)
autograd.training = <span class="hljs-literal">False</span>
test(new_model, valid_x, valid_y, dev=dev)
</code></pre>
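<p>The re_train loop follows the usual mini-batch SGD pattern: forward pass, softmax cross-entropy loss, backward pass, parameter update. The sketch below shows the same pattern in plain NumPy on a single linear layer, with the softmax cross-entropy gradient written out by hand in place of <code>autograd.backward</code> and <code>opt.SGD</code>. It is an illustrative stand-in, not SINGA code; <code>train_linear</code> and the toy data are hypothetical.</p>
<pre><code class="hljs css language-python">import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def train_linear(x, y, epochs=20, batch_size=16, lr=0.5, seed=0):
    # Plain-NumPy stand-in for re_train: w plays the role of the model weights
    rng = np.random.default_rng(seed)
    w = rng.normal(scale=0.01, size=(x.shape[1], y.shape[1]))
    batch_number = x.shape[0] // batch_size
    loss = None
    for epoch in range(epochs):
        for b in range(batch_number):
            xb = x[b * batch_size:(b + 1) * batch_size]
            yb = y[b * batch_size:(b + 1) * batch_size]
            p = softmax(xb @ w)                                        # forward pass
            loss = -np.mean(np.sum(yb * np.log(p + 1e-12), axis=1))    # cross-entropy
            grad = xb.T @ (p - yb) / batch_size                        # d(loss)/dw
            w -= lr * grad                                             # SGD update
    return w, loss

# Toy data: the label is the index of the largest of four features
rng = np.random.default_rng(1)
x = rng.normal(size=(64, 4))
y = np.eye(4)[np.argmax(x, axis=1)]
w, final_loss = train_linear(x, y)
print("final batch loss %.3f" % final_loss)
</code></pre>
<p>Only whole batches are used: with 64 samples and a batch size of 16 every sample is seen, but any samples beyond a multiple of the batch size would be skipped, just as in re_train.</p>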
<h3><a class="anchor" aria-hidden="true" id="transfer-learning"></a><a href="#transfer-learning" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Transfer learning</h3>
<p>Finally, to do transfer learning, you can define a class called
<strong>Trans</strong> that appends some layers after the ONNX model. For demonstration, the
code only appends several linear (fully connected) and ReLU layers after the ONNX
model. You can define a transfer_learning function to handle the training
process of the transfer-learning model. The <strong>autograd.training</strong> flag is set in the
same way as in the previous example.</p>
<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Trans</span>:</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir, last_layers)</span>:</span>
        self.sg_ir = sg_ir
        self.last_layers = last_layers
        self.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=<span class="hljs-literal">False</span>)
        self.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=<span class="hljs-literal">False</span>)
        self.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
        <span class="hljs-comment"># run the imported model only up to self.last_layers</span>
        y = self.sg_ir.run([x], last_layers=self.last_layers)[<span class="hljs-number">0</span>]
        <span class="hljs-comment"># then run the newly appended layers</span>
        y = self.append_linear1(y)
        y = autograd.relu(y)
        y = self.append_linear2(y)
        y = autograd.relu(y)
        y = self.append_linear3(y)
        y = autograd.relu(y)
        <span class="hljs-keyword">return</span> y

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">transfer_learning</span><span class="hljs-params">(sg_ir,
                      x,
                      y,
                      epochs=<span class="hljs-number">1</span>,
                      batch_size=<span class="hljs-number">64</span>,
                      dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
    trans_model = Trans(sg_ir, <span class="hljs-number">-1</span>)
    <span class="hljs-comment"># create the optimizer once, outside the training loop</span>
    sgd = opt.SGD(lr=<span class="hljs-number">0.07</span>)
    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
            l_idx = b * batch_size
            r_idx = (b + <span class="hljs-number">1</span>) * batch_size
            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
            output_batch = trans_model.forward(x_batch)
            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
                                     tensor.to_numpy(target_batch))
            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
                sgd.update(p, gp)
            sgd.step()
            <span class="hljs-keyword">if</span> b % <span class="hljs-number">100</span> == <span class="hljs-number">0</span>:
                print(<span class="hljs-string">"acc %6.2f, loss %6.2f"</span> %
                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
    print(<span class="hljs-string">"transfer-learning completed"</span>)
    <span class="hljs-keyword">return</span> trans_model

<span class="hljs-comment"># load the ONNX model</span>
onnx_model = onnx.load(model_path)
sg_ir = sonnx.prepare(onnx_model, device=dev)
<span class="hljs-comment"># transfer-learning</span>
autograd.training = <span class="hljs-literal">True</span>
new_model = transfer_learning(sg_ir, train_x, train_y, dev=dev)
autograd.training = <span class="hljs-literal">False</span>
test(new_model, valid_x, valid_y, dev=dev)
</code></pre>
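<p>To make the shapes in Trans concrete: because <code>append_linear1</code> takes 500 inputs, the code assumes that running the imported MNIST model with <code>last_layers=-1</code> yields 500 features per sample. The NumPy sketch below traces a batch through three bias-free linear layers with ReLU (500 to 128 to 32 to 10), using random weights purely for illustration. It mirrors the forward method of Trans but is not SINGA code.</p>
<pre><code class="hljs css language-python">import numpy as np

rng = np.random.default_rng(0)
# Random weights standing in for append_linear1..3 (bias=False, so no bias terms)
w1 = rng.normal(scale=0.05, size=(500, 128))
w2 = rng.normal(scale=0.05, size=(128, 32))
w3 = rng.normal(scale=0.05, size=(32, 10))

def relu(z):
    return np.maximum(z, 0)

def head_forward(features):
    # features: output of the truncated ONNX model, shape (batch, 500)
    y = relu(features @ w1)
    y = relu(y @ w2)
    y = relu(y @ w3)      # Trans also applies ReLU after the last appended layer
    return y

features = rng.normal(size=(64, 500))
out = head_forward(features)
print(out.shape)   # (64, 10)
</code></pre>
<p>The appended head adds 500*128 + 128*32 + 32*10 = 68,416 trainable weights.</p>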
<h2><a class="anchor" aria-hidden="true" id="onnx-model-zoo"></a><a href="#onnx-model-zoo" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>ONNX model zoo</h2>
<p>The <a href="https://github.com/onnx/models">ONNX Model Zoo</a> is a collection of
pre-trained, state-of-the-art models in the ONNX format contributed by community
members. SINGA currently supports several CV and NLP models from the zoo, and more
models will be supported soon.</p>
<h3><a class="anchor" aria-hidden="true" id="image-classification"></a><a href="#image-classification" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Image Classification</h3>
<p>These models take images as input and classify the major objects in the
images into 1000 object categories, such as keyboard, mouse, pencil, and many
animals.</p>
<table>
<thead>
<tr><th>Model Class</th><th>Reference</th><th>Description</th><th>Link</th></tr>
</thead>
<tbody>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/mobilenet">MobileNet</a></b></td><td><a href="https://arxiv.org/abs/1801.04381">Sandler et al.</a></td><td>Light-weight deep neural network best suited for mobile and embedded vision applications. <br>Top-5 error from paper - ~10%</td><td><a href="https://colab.research.google.com/drive/1HsixqJMIpKyEPhkbB8jy7NwNEFEAUWAf"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/resnet">ResNet18</a></b></td><td><a href="https://arxiv.org/abs/1512.03385">He et al.</a></td><td>A CNN model (up to 152 layers). Uses shortcut connections to achieve higher accuracy when classifying images. <br> Top-5 error from paper - ~3.6%</td><td><a href="https://colab.research.google.com/drive/1u1RYefSsVbiP4I-5wiBKHjsT9L0FxLm9"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/vgg">VGG16</a></b></td><td><a href="https://arxiv.org/abs/1409.1556">Simonyan et al.</a></td><td>Deep CNN model (up to 19 layers). Similar to AlexNet but uses multiple smaller kernel-sized filters, which provide higher accuracy when classifying images. <br>Top-5 error from paper - ~8%</td><td><a href="https://colab.research.google.com/drive/14kxgRKtbjPCKKsDJVNi3AvTev81Gp_Ds"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
</tbody>
</table>
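<p>The error figures in the table are top-5 errors as reported in the cited papers: a prediction counts as correct if the true class is among the five highest-scoring classes. A small NumPy sketch of that metric follows; <code>top5_error</code> is an illustrative name, not part of SINGA.</p>
<pre><code class="hljs css language-python">import numpy as np

def top5_error(scores, labels):
    # scores: (N, num_classes) model outputs; labels: (N,) true class indices
    top5 = np.argsort(scores, axis=1)[:, -5:]          # indices of the 5 largest scores
    hit = np.any(top5 == labels[:, None], axis=1)      # is the true class among them?
    return 1.0 - hit.mean()

scores = np.array([
    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7],   # true class 6 has the top score
    [0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],   # true class 6 has the lowest score
])
labels = np.array([6, 6])
print(top5_error(scores, labels))   # 0.5
</code></pre>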
<h3><a class="anchor" aria-hidden="true" id="object-detection"></a><a href="#object-detection" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Object Detection</h3>
<p>Object detection models detect the presence of multiple objects in an image and
segment out areas of the image where the objects are detected.</p>
<table>
<thead>
<tr><th>Model Class</th><th>Reference</th><th>Description</th><th>Link</th></tr>
</thead>
<tbody>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/tiny_yolov2">Tiny YOLOv2</a></b></td><td><a href="https://arxiv.org/pdf/1612.08242.pdf">Redmon et al.</a></td><td>A real-time CNN for object detection that detects 20 different classes. A smaller version of the more complex full YOLOv2 network.</td><td><a href="https://colab.research.google.com/drive/11V4I6cRjIJNUv5ZGsEGwqHuoQEie6b1T"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
</tbody>
</table>
<h3><a class="anchor" aria-hidden="true" id="face-analysis"></a><a href="#face-analysis" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Face Analysis</h3>
<p>Face detection models identify and/or recognize human faces and emotions in
given images.</p>
<table>
<thead>
<tr><th>Model Class</th><th>Reference</th><th>Description</th><th>Link</th></tr>
</thead>
<tbody>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/body_analysis/arcface">ArcFace</a></b></td><td><a href="https://arxiv.org/abs/1801.07698">Deng et al.</a></td><td>A CNN based model for face recognition which learns discriminative features of faces and produces embeddings for input face images.</td><td><a href="https://colab.research.google.com/drive/1qanaqUKGIDtifdzEzJOHjEj4kYzA9uJC"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/body_analysis/emotion_ferplus">Emotion FerPlus</a></b></td><td><a href="https://arxiv.org/abs/1608.01041">Barsoum et al.</a></td><td>Deep CNN for emotion recognition trained on images of faces.</td><td><a href="https://colab.research.google.com/drive/1XHtBQGRhe58PDi4LGYJzYueWBeWbO23r"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
</tbody>
</table>
<h3><a class="anchor" aria-hidden="true" id="machine-comprehension"></a><a href="#machine-comprehension" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Machine Comprehension</h3>
<p>This subset of natural language processing models answers questions about a
given context paragraph.</p>
<table>
<thead>
<tr><th>Model Class</th><th>Reference</th><th>Description</th><th>Link</th></tr>
</thead>
<tbody>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/text/machine_comprehension/bert-squad">BERT-Squad</a></b></td><td><a href="https://arxiv.org/pdf/1810.04805.pdf">Devlin et al.</a></td><td>This model answers questions based on the context of the given input paragraph.</td><td><a href="https://colab.research.google.com/drive/1kud-lUPjS_u-TkDAzihBTw0Vqr0FjCE-"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
</tbody>
</table>
<h2><a class="anchor" aria-hidden="true" id="supported-operators"></a><a href="#supported-operators" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Supported operators</h2>
<p>The following operators are supported:</p>
<ul>
<li>Conv</li>
<li>Relu</li>
<li>Constant</li>
<li>MaxPool</li>
<li>AveragePool</li>
<li>Softmax</li>
<li>Sigmoid</li>
<li>Add</li>
<li>MatMul</li>
<li>BatchNormalization</li>
<li>Concat</li>
<li>Flatten</li>
<li>Gemm</li>
<li>Reshape</li>
<li>Sum</li>
<li>Cos</li>
<li>Cosh</li>
<li>Sin</li>
<li>Sinh</li>
<li>Tan</li>
<li>Tanh</li>
<li>Acos</li>
<li>Acosh</li>
<li>Asin</li>
<li>Asinh</li>
<li>Atan</li>
<li>Atanh</li>
<li>Selu</li>
<li>Elu</li>
<li>Equal</li>
<li>Less</li>
<li>Sign</li>
<li>Div</li>
<li>Sub</li>
<li>Sqrt</li>
<li>Log</li>
<li>Greater</li>
<li>HardSigmoid</li>
<li>Identity</li>
<li>Softplus</li>
<li>Softsign</li>
<li>Mean</li>
<li>Pow</li>
<li>Clip</li>
<li>PRelu</li>
<li>Mul</li>
<li>Transpose</li>
<li>Max</li>
<li>Min</li>
<li>Shape</li>
<li>And</li>
<li>Or</li>
<li>Xor</li>
<li>Not</li>
<li>Neg</li>
<li>Reciprocal</li>
<li>LeakyRelu</li>
<li>GlobalAveragePool</li>
<li>ConstantOfShape</li>
<li>Dropout</li>
<li>ReduceSum</li>
<li>ReduceMean</li>
<li>Squeeze</li>
<li>Unsqueeze</li>
<li>Slice</li>
<li>Ceil</li>
<li>Split</li>
<li>Gather</li>
<li>Tile</li>
<li>NonZero</li>
<li>Cast</li>
<li>OneHot</li>
</ul>
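<p>Before importing a model, it can help to check its operator types against this list. With the <code>onnx</code> package, each node in <code>model.graph.node</code> exposes its type as <code>node.op_type</code>. The sketch below works on a plain list of operator names so that it runs without ONNX installed; <code>SUPPORTED</code> is only an excerpt of the list above, and <code>unsupported_ops</code> is a hypothetical helper.</p>
<pre><code class="hljs css language-python">SUPPORTED = {
    "Conv", "Relu", "MaxPool", "AveragePool", "Softmax", "Add", "MatMul",
    "BatchNormalization", "Concat", "Flatten", "Gemm", "Reshape", "Dropout",
}  # excerpt of the supported-operator list

def unsupported_ops(op_types):
    # Return the operator types not in SUPPORTED, de-duplicated and sorted
    return sorted(set(op_types) - SUPPORTED)

# With onnx installed, this list could come from:
#   [node.op_type for node in onnx.load(model_path).graph.node]
op_types = ["Conv", "Relu", "MaxPool", "Conv", "LRN", "Gemm", "Softmax"]
print(unsupported_ops(op_types))   # ['LRN']
</code></pre>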
<h3><a class="anchor" aria-hidden="true" id="special-comments-for-onnx-backend"></a><a href="#special-comments-for-onnx-backend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Special comments for ONNX backend</h3>
<ul>
<li><p>Conv, MaxPool and AveragePool</p>
<p>Input must have a 1D (<code>N*C*H</code>) or 2D (<code>N*C*H*W</code>) shape, and <code>dilation</code> must be 1.</p></li>
<li><p>BatchNormalization</p>
<p><code>epsilon</code> is 1e-05 and cannot be changed.</p></li>
<li><p>Cast</p>
<p>Only float32 and int32 are supported; other types are cast to one of these two types.</p></li>
<li><p>Squeeze and Unsqueeze</p>
<p>If you encounter errors when you <code>Squeeze</code> or <code>Unsqueeze</code> between a <code>Tensor</code> and a
scalar, please report them to us.</p></li>
<li><p>Empty tensor</p>
<p>Empty tensors are illegal in SINGA.</p></li>
</ul>
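<p>The Conv and pooling restriction can be checked up front. The helper below is hypothetical (SINGA performs its own handling when it translates a node); it only encodes the two rules stated above: a 3-D <code>(N, C, H)</code> or 4-D <code>(N, C, H, W)</code> input, and dilations of 1. In ONNX, <code>dilations</code> is an optional attribute of Conv that defaults to 1 along each spatial axis, so a missing value is treated as all ones.</p>
<pre><code class="hljs css language-python">def check_conv_like(input_shape, dilations=None):
    # Rule 1: input is (N, C, H) for 1-D or (N, C, H, W) for 2-D operators
    if len(input_shape) not in (3, 4):
        raise ValueError("expected a 3-D or 4-D input, got shape %s" % (input_shape,))
    # Rule 2: every dilation must be 1 (a missing attribute means all ones)
    spatial = len(input_shape) - 2
    dilations = dilations or [1] * spatial
    if any(d != 1 for d in dilations):
        raise ValueError("dilation must be 1, got %s" % (list(dilations),))
    return True

print(check_conv_like((8, 3, 32, 32)))          # True
print(check_conv_like((8, 16, 100), [1]))       # True
try:
    check_conv_like((8, 3, 32, 32), [2, 2])
except ValueError as e:
    print(e)                                     # dilation must be 1, got [2, 2]
</code></pre>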
<h2><a class="anchor" aria-hidden="true" id="implementation"></a><a href="#implementation" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Implementation</h2>
<p>The code of SINGA ONNX is located at <code>python/singa/sonnx.py</code>. There are three main
classes: <code>SingaFrontend</code>, <code>SingaBackend</code>, and <code>SingaRep</code>. <code>SingaFrontend</code>
translates a SINGA model to an ONNX model; <code>SingaBackend</code> translates an ONNX model
to a <code>SingaRep</code> object, which stores all SINGA operators and tensors (in
this doc, "tensor" means a SINGA <code>Tensor</code>); <code>SingaRep</code> can be run like a SINGA model.</p>
<h3><a class="anchor" aria-hidden="true" id="singafrontend"></a><a href="#singafrontend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaFrontend</h3>
<p>The entry function of <code>SingaFrontend</code> is <code>singa_to_onnx_model</code>, which is also
called <code>to_onnx</code>. <code>singa_to_onnx_model</code> creates the ONNX model, and it also
creates an ONNX graph by calling <code>singa_to_onnx_graph</code>.</p>
<p><code>singa_to_onnx_graph</code> accepts the output of the model and recursively iterates over
the SINGA model's graph, starting from the output, to collect all operators into a queue.
The input and intermediate tensors (i.e., trainable weights) of the SINGA model
are picked up at the same time. The input is stored in <code>onnx_model.graph.input</code>;
the output is stored in <code>onnx_model.graph.output</code>; and the trainable weights are
stored in <code>onnx_model.graph.initializer</code>.</p>
<p>Then the SINGA operators in the queue are translated to ONNX operators one by one.
<code>_rename_operators</code> defines the operator name mapping between SINGA and ONNX.
<code>_special_operators</code> defines which function is used to translate each
operator.</p>
<p>In addition, some SINGA operators are defined differently from their ONNX
counterparts: ONNX treats some attributes of SINGA operators as inputs, so
<code>_unhandled_operators</code> defines which function handles each such special operator.</p>
<p>Since SINGA treats the bool type as int32, <code>_bool_operators</code> defines the
operators that need to be changed to the bool type.</p>
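<p>The traversal described above can be pictured with a toy graph. The sketch below is not SINGA's code: <code>Op</code>, <code>RENAME</code>, <code>to_queue</code>, and the toy type names are hypothetical. It walks back from the output through each operator's inputs depth-first, appends every operator only after the operators that produce its inputs (so the queue is in execution order), and renames types with a small table in the spirit of <code>_rename_operators</code>.</p>
<pre><code class="hljs css language-python">from collections import namedtuple

# Toy operator: a name, a toy type name, and the ops producing its inputs
Op = namedtuple("Op", ["name", "type", "inputs"])

# Hypothetical renaming table in the spirit of _rename_operators
RENAME = {"ReLU": "Relu", "Dense": "Gemm", "SoftMax": "Softmax"}

def to_queue(output_op):
    # Post-order DFS from the output: an op is appended only after all
    # ops feeding it, and shared ops are visited once
    queue, seen = [], set()
    def visit(op):
        if op.name in seen:
            return
        seen.add(op.name)
        for parent in op.inputs:
            visit(parent)
        queue.append(op)
    visit(output_op)
    return [(op.name, RENAME.get(op.type, op.type)) for op in queue]

conv = Op("conv1", "Conv", [])
relu = Op("relu1", "ReLU", [conv])
fc = Op("fc1", "Dense", [relu])
out = Op("prob", "SoftMax", [fc])
print(to_queue(out))
# [('conv1', 'Conv'), ('relu1', 'Relu'), ('fc1', 'Gemm'), ('prob', 'Softmax')]
</code></pre>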
<h3><a class="anchor" aria-hidden="true" id="singabackend"></a><a href="#singabackend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaBackend</h3>
<p>The entry function of <code>SingaBackend</code> is <code>prepare</code>, which checks the version of
the ONNX model and then calls <code>_onnx_model_to_singa_net</code>.</p>
<p>The purpose of <code>_onnx_model_to_singa_net</code> is to get the SINGA tensors and operators.
The tensors are stored in a dictionary keyed by their ONNX names, and the operators are
stored in a queue in the form of
<code>namedtuple('SingaOps', ['name', 'op', 'handle', 'forward'])</code>. For each
operator, <code>name</code> is its ONNX node name; <code>op</code> is the ONNX node; <code>forward</code> is the
SINGA operator's forward function; and <code>handle</code> is prepared for special
operators, such as Conv and Pooling, that have a <code>handle</code> object.</p>
<p>The first step of <code>_onnx_model_to_singa_net</code> is to call <code>_init_graph_parameter</code>
to get all tensors within the model. For trainable weights, it initializes SINGA
<code>Tensor</code>s from <code>onnx_model.graph.initializer</code>. Note that the weights may also
be stored in the graph's input or in an ONNX node called <code>Constant</code>; SINGA
handles these cases as well.</p>
<p>Although all weights are stored within the ONNX model, only the shape and type of
the model's input are known. SINGA therefore supports two ways to initialize the input:
(1) generate a random tensor from its shape and type, or (2) let the user assign the
input. The first way works fine for most models; however, for some models, such as
BERT, the matrix indices cannot be randomly generated, otherwise errors will
occur.</p>
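<p>The two input-initialization strategies can be sketched as follows. <code>init_input</code> is a hypothetical helper, not SINGA's API, and NumPy arrays stand in for SINGA tensors. Random floating-point data is usually harmless for an image input, whereas an index input such as BERT's token ids needs values chosen by the user; the ids below are purely illustrative.</p>
<pre><code class="hljs css language-python">import numpy as np

def init_input(shape, dtype, user_value=None, seed=0):
    # Strategy 2: the user assigns the input explicitly
    if user_value is not None:
        value = np.asarray(user_value, dtype=dtype)
        if value.shape != tuple(shape):
            raise ValueError("assigned input has shape %s, expected %s"
                             % (value.shape, tuple(shape)))
        return value
    # Strategy 1: only shape and dtype are known, so fill with random data
    rng = np.random.default_rng(seed)
    return rng.standard_normal(size=shape).astype(dtype)

# Random data is acceptable for an image input ...
x = init_input((1, 3, 224, 224), np.float32)
# ... but index inputs must hold valid indices, so they are assigned
token_ids = init_input((1, 4), np.int64, user_value=[[101, 7592, 2088, 102]])
print(x.shape, token_ids.dtype)
</code></pre>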
<p>Then, <code>_onnx_model_to_singa_net</code> iterates over all nodes in the ONNX graph to
translate them into SINGA operators. As in the frontend, <code>_rename_operators</code> defines
the operator name mapping between SINGA and ONNX, and <code>_special_operators</code> defines which
function is used to translate each operator. <code>_run_node</code> runs each generated SINGA
operator on its input tensors and stores its output tensors for use by later
operators.</p>
<p>Finally, this class returns a <code>SingaRep</code> object that stores all the SINGA tensors and
operators.</p>
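<p>The queue-and-dictionary design can be illustrated with a toy two-node graph. In the sketch below, <code>SingaOps</code> reuses the field names quoted above, but the nodes, the NumPy forward functions, and <code>run_node</code> are hypothetical stand-ins for ONNX nodes, SINGA operators, and <code>_run_node</code>. Each operator reads its inputs from the tensor dictionary by ONNX name and writes its outputs back, so later operators can find them.</p>
<pre><code class="hljs css language-python">from collections import namedtuple
import numpy as np

SingaOps = namedtuple('SingaOps', ['name', 'op', 'handle', 'forward'])
# Minimal stand-in for an ONNX NodeProto: only the fields used here
Node = namedtuple('Node', ['name', 'input', 'output'])

tensors = {
    'x': np.array([[1.0, -2.0]]),                 # model input
    'W': np.array([[1.0, 0.0], [0.0, 1.0]]),      # weight from the initializer
}
queue = [
    SingaOps('matmul0', Node('matmul0', ['x', 'W'], ['h']), None, np.matmul),
    SingaOps('relu0', Node('relu0', ['h'], ['y']), None, lambda h: np.maximum(h, 0)),
]

def run_node(singa_op, tensors):
    # Look up the inputs by ONNX name, run the forward function,
    # and store the outputs under their ONNX names for later operators
    args = [tensors[name] for name in singa_op.op.input]
    result = singa_op.forward(*args)
    outputs = result if isinstance(result, tuple) else (result,)
    for name, value in zip(singa_op.op.output, outputs):
        tensors[name] = value

for singa_op in queue:
    run_node(singa_op, tensors)
print(tensors['y'])   # [[1. 0.]]
</code></pre>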
<h3><a class="anchor" aria-hidden="true" id="singarep"></a><a href="#singarep" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaRep</h3>
<p><code>SingaRep</code> stores all SINGA tensors and operators. <code>run</code> accepts the input
of the model and runs the SINGA operators one by one, following the operator
queue. You can use <code>last_layers</code> to run the model only up to the last
few layers. Set <code>all_outputs</code> to <code>False</code> to get only the final output, or to
<code>True</code> to also get all the intermediate outputs.</p>
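<p>Conceptually, <code>last_layers</code> truncates the operator queue and <code>all_outputs</code> chooses between the final output and every intermediate one. The toy <code>run</code> below assumes that a negative <code>last_layers</code> drops that many operators from the end of the queue (Python slice semantics); this is an assumption for illustration, so consult <code>sonnx.py</code> for SINGA's exact behavior.</p>
<pre><code class="hljs css language-python">def run(ops, x, last_layers=None, all_outputs=False):
    # ops: ordered list of (name, function) pairs, standing in for the queue
    outputs = {}
    y = x
    for name, fn in ops[:last_layers]:   # None keeps the whole queue
        y = fn(y)
        outputs[name] = y
    return outputs if all_outputs else y

ops = [
    ("double", lambda v: v * 2),
    ("add_one", lambda v: v + 1),
    ("square", lambda v: v ** 2),
]
print(run(ops, 3))                     # 49
print(run(ops, 3, last_layers=-1))     # 7
print(run(ops, 3, all_outputs=True))   # {'double': 6, 'add_one': 7, 'square': 49}
</code></pre>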
</span></div></article></div><div class="docLastUpdate"><em>Last updated on 4/5/2020</em></div><div class="docs-prevnext"></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#general-usage">General usage</a><ul class="toc-headings"><li><a href="#loading-an-onnx-model-into-singa">Loading an ONNX Model into SINGA</a></li><li><a href="#inferernce-singa-model">Inferernce SINGA model</a></li><li><a href="#saving-an-onnx-model-from-singa">Saving an ONNX Model from SINGA</a></li><li><a href="#re-training-a-onnx-model">Re-training a ONNX model</a></li><li><a href="#transfer-learning-a-onnx-model">Transfer-learning a ONNX model</a></li></ul></li><li><a href="#example-onnx-mnist-on-singa">Example: ONNX mnist on SINGA</a><ul class="toc-headings"><li><a href="#load-dataset">Load dataset</a></li><li><a href="#mnist-model">MNIST model</a></li><li><a href="#train-mnist-model-and-export-it-to-onnx">Train mnist model and export it to onnx</a></li><li><a href="#inference">Inference</a></li><li><a href="#re-training">Re-training</a></li><li><a href="#transfer-learning">Transfer learning</a></li></ul></li><li><a href="#onnx-model-zoo">ONNX model zoo</a><ul class="toc-headings"><li><a href="#image-classification">Image Classification</a></li><li><a href="#object-detection">Object Detection</a></li><li><a href="#face-analysis">Face Analysis</a></li><li><a href="#machine-comprehension">Machine Comprehension</a></li></ul></li><li><a href="#supported-operators">Supported operators</a><ul class="toc-headings"><li><a href="#special-comments-for-onnx-backend">Special comments for ONNX backend</a></li></ul></li><li><a href="#implementation">Implementation</a><ul class="toc-headings"><li><a href="#singafrontend">SingaFrontend</a></li><li><a href="#singabackend">SingaBackend</a></li><li><a href="#singarep">SingaRep</a></li></ul></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img 
src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a href="/en/#">API Reference (coming soon)</a><a href="/docs/model-zoo-cnn-cifar10">Model Zoo</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/news">SINGA News</a><a href="https://github.com/apache/singa-doc">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa-doc" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2020
The Apache Software Foundation. All rights reserved.
Apache SINGA, Apache, the Apache feather logo, and
the Apache SINGA project logos are trademarks of The
Apache Software Foundation. All other marks mentioned
may be trademarks or registered trademarks of their
respective owners.</section></div></footer></div><script>window.twttr=(function(d,s, id){var js,fjs=d.getElementsByTagName(s)[0],t=window.twttr||{};if(d.getElementById(id))return t;js=d.createElement(s);js.id=id;js.src='https://platform.twitter.com/widgets.js';fjs.parentNode.insertBefore(js, fjs);t._e = [];t.ready = function(f) {t._e.push(f);};return t;}(document, 'script', 'twitter-wjs'));</script></body></html>