| <!DOCTYPE html><html lang="en"><head><meta charSet="utf-8"/><meta http-equiv="X-UA-Compatible" content="IE=edge"/><title>ONNX · Apache SINGA</title><meta name="viewport" content="width=device-width"/><meta name="generator" content="Docusaurus"/><meta name="description" content="ONNX is an open representation format for machine learning models, which enables AI developers to use models across different libraries and tools. SINGA supports loading ONNX format models for training and inference, and saving models defined using SINGA APIs into ONNX format."/><meta name="docsearch:version" content="3.0.0"/><meta name="docsearch:language" content="en"/><meta property="og:title" content="ONNX · Apache SINGA"/><meta property="og:type" content="website"/><meta property="og:url" content="https://singa.apache.org/"/><meta property="og:description" content="ONNX is an open representation format for machine learning models, which enables AI developers to use models across different libraries and tools. SINGA supports loading ONNX format models for training and inference, and saving models defined using SINGA APIs into ONNX format."/><meta property="og:image" content="https://singa.apache.org/img/singa_twitter_banner.jpeg"/><meta name="twitter:card" content="summary"/><meta name="twitter:image" content="https://singa.apache.org/img/singa_twitter_banner.jpeg"/><link rel="shortcut icon" href="/img/favicon.ico"/><link rel="stylesheet" href="https://cdn.jsdelivr.net/docsearch.js/1/docsearch.min.css"/><link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/atom-one-dark.min.css"/><link rel="alternate" type="application/atom+xml" href="https://singa.apache.org/blog/atom.xml" title="Apache SINGA Blog ATOM Feed"/><link rel="alternate" type="application/rss+xml" href="https://singa.apache.org/blog/feed.xml" title="Apache SINGA Blog RSS Feed"/><script type="text/javascript" src="https://buttons.github.io/buttons.js"></script><script src="https://unpkg.com/vanilla-back-to-top@7.1.14/dist/vanilla-back-to-top.min.js"></script><script> |
| document.addEventListener('DOMContentLoaded', function() { |
| addBackToTop( |
| {"zIndex":100} |
| ) |
| }); |
| </script><script src="/js/scrollSpy.js"></script><link rel="stylesheet" href="/css/main.css"/><script src="/js/codetabs.js"></script></head><body class="sideNavVisible separateOnPageNav"><div class="fixedHeaderContainer"><div class="headerWrapper wrapper"><header><a href="/"><img class="logo" src="/img/singa.png" alt="Apache SINGA"/></a><a href="/versions"><h3>3.0.0</h3></a><div class="navigationWrapper navigationSlider"><nav class="slidingNav"><ul class="nav-site nav-site-internal"><li class="siteNavGroupActive"><a href="/docs/3.0.0/installation" target="_self">Docs</a></li><li class=""><a href="/docs/3.0.0/source-repository" target="_self">Community</a></li><li class=""><a href="/blog/" target="_self">News</a></li><li class=""><a href="https://apache-singa.readthedocs.io/en/latest/" target="_self">API</a></li><li class="navSearchWrapper reactNavSearchWrapper"><input type="text" id="search_input_react" placeholder="Search" title="Search"/></li><li class=""><a href="https://github.com/apache/singa" target="_self">GitHub</a></li></ul></nav></div></header></div></div><div class="navPusher"><div class="docMainWrapper wrapper"><div class="docsNavContainer" id="docsNav"><nav class="toc"><div class="toggleNav"><section class="navWrapper wrapper"><div class="navBreadcrumb wrapper"><div class="navToggle" id="navToggler"><div class="hamburger-menu"><div class="line1"></div><div class="line2"></div><div class="line3"></div></div></div><h2><i>›</i><span>Guides</span></h2><div class="tocToggler" id="tocToggler"><i class="icon-toc"></i></div></div><div class="navGroups"><div class="navGroup"><h3 class="navGroupCategoryTitle">Getting Started</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/3.0.0/installation">Installation</a></li><li class="navListItem"><a class="navItem" href="/docs/3.0.0/software-stack">Software Stack</a></li><li class="navListItem"><a class="navItem" href="/docs/3.0.0/examples">Examples</a></li></ul></div><div class="navGroup"><h3 
class="navGroupCategoryTitle">Guides</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/3.0.0/device">Device</a></li><li class="navListItem"><a class="navItem" href="/docs/3.0.0/tensor">Tensor</a></li><li class="navListItem"><a class="navItem" href="/docs/3.0.0/autograd">Autograd</a></li><li class="navListItem"><a class="navItem" href="/docs/3.0.0/graph">Computational Graph</a></li><li class="navListItem navListItemActive"><a class="navItem" href="/docs/3.0.0/onnx">ONNX</a></li><li class="navListItem"><a class="navItem" href="/docs/3.0.0/dist-train">Distributed Training</a></li></ul></div><div class="navGroup"><h3 class="navGroupCategoryTitle">Development</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/3.0.0/downloads">Download SINGA</a></li><li class="navListItem"><a class="navItem" href="/docs/3.0.0/build">Build SINGA from Source</a></li><li class="navListItem"><a class="navItem" href="/docs/3.0.0/contribute-code">How to Contribute Code</a></li><li class="navListItem"><a class="navItem" href="/docs/3.0.0/contribute-docs">How to Contribute to Documentation</a></li><li class="navListItem"><a class="navItem" href="/docs/3.0.0/how-to-release">How to Prepare a Release</a></li><li class="navListItem"><a class="navItem" href="/docs/3.0.0/git-workflow">Git Workflow</a></li></ul></div></div></section></div><script> |
| var coll = document.getElementsByClassName('collapsible'); |
| var checkActiveCategory = true; |
| for (var i = 0; i < coll.length; i++) { |
| var links = coll[i].nextElementSibling.getElementsByTagName('*'); |
| if (checkActiveCategory){ |
| for (var j = 0; j < links.length; j++) { |
| if (links[j].classList.contains('navListItemActive')){ |
| coll[i].nextElementSibling.classList.toggle('hide'); |
| coll[i].childNodes[1].classList.toggle('rotate'); |
| checkActiveCategory = false; |
| break; |
| } |
| } |
| } |
| |
| coll[i].addEventListener('click', function() { |
| var arrow = this.childNodes[1]; |
| arrow.classList.toggle('rotate'); |
| var content = this.nextElementSibling; |
| content.classList.toggle('hide'); |
| }); |
| } |
| |
| document.addEventListener('DOMContentLoaded', function() { |
| createToggler('#navToggler', '#docsNav', 'docsSliderActive'); |
| createToggler('#tocToggler', 'body', 'tocActive'); |
| |
| var headings = document.querySelector('.toc-headings'); |
| headings && headings.addEventListener('click', function(event) { |
| var el = event.target; |
| while(el !== headings){ |
| if (el.tagName === 'A') { |
| document.body.classList.remove('tocActive'); |
| break; |
| } else{ |
| el = el.parentNode; |
| } |
| } |
| }, false); |
| |
| function createToggler(togglerSelector, targetSelector, className) { |
| var toggler = document.querySelector(togglerSelector); |
| var target = document.querySelector(targetSelector); |
| |
| if (!toggler) { |
| return; |
| } |
| |
| toggler.onclick = function(event) { |
| event.preventDefault(); |
| |
| target.classList.toggle(className); |
| }; |
| } |
| }); |
| </script></nav></div><div class="container mainContainer docsContainer"><div class="wrapper"><div class="post"><header class="postHeader"><a class="edit-page-link button" href="https://github.com/apache/singa-doc/blob/master/docs-site/docs/onnx.md" target="_blank" rel="noreferrer noopener">Edit</a><h1 id="__docusaurus" class="postHeaderTitle">ONNX</h1></header><article><div><span><!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> |
| <p><a href="https://onnx.ai/">ONNX</a> is an open representation format for machine learning |
| models, which enables AI developers to use models across different libraries and |
| tools. SINGA supports loading ONNX format models for training and inference, and |
| saving models defined using SINGA APIs (e.g., <a href="./module">Module</a>) into ONNX |
| format.</p> |
| <p>SINGA has been tested with the following |
| <a href="https://github.com/onnx/onnx/blob/master/docs/Versioning.md">version</a> of ONNX.</p> |
| <table> |
| <thead> |
| <tr><th>ONNX version</th><th>File format version</th><th>Opset version ai.onnx</th><th>Opset version ai.onnx.ml</th><th>Opset version ai.onnx.training</th></tr> |
| </thead> |
| <tbody> |
| <tr><td>1.6.0</td><td>6</td><td>11</td><td>2</td><td>-</td></tr> |
| </tbody> |
| </table> |
| <h2><a class="anchor" aria-hidden="true" id="general-usage"></a><a href="#general-usage" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>General usage</h2> |
| <h3><a class="anchor" aria-hidden="true" id="loading-an-onnx-model-into-singa"></a><a href="#loading-an-onnx-model-into-singa" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Loading an ONNX Model into SINGA</h3> |
| <p>After loading an ONNX model from disk with <code>onnx.load</code>, you need to update the |
| model's batch size, because most models use a placeholder to represent their |
| batch size. An example is given below as <code>update_batch_size</code>. You only need to |
| update the batch size of the input and output; the shapes of the internal tensors will be |
| inferred automatically.</p> |
| <p>Then, you can prepare the SINGA model by using <code>sonnx.prepare</code>. This function |
| iterates and translates all the nodes within the ONNX model's graph into SINGA |
| operators, loads all stored weights and infers each intermediate tensor's shape.</p> |
| <pre><code class="hljs css language-python3"><span class="hljs-built_in">import</span> onnx |
| from singa <span class="hljs-built_in">import</span> device |
| from singa <span class="hljs-built_in">import</span> sonnx |
| |
| <span class="hljs-comment"># if the input has multiple tensors? can put this function inside prepare()?</span> |
| def update_batch_size(onnx_model, batch_size): |
| <span class="hljs-attr">model_input</span> = onnx_model.graph.input[<span class="hljs-number">0</span>] |
| model_input.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size |
| <span class="hljs-attr">model_output</span> = onnx_model.graph.output[<span class="hljs-number">0</span>] |
| model_output.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size |
| return onnx_model |
| |
| |
| <span class="hljs-attr">model_path</span> = <span class="hljs-string">"PATH/To/ONNX/MODEL"</span> |
| <span class="hljs-attr">onnx_model</span> = onnx.load(model_path) |
| |
| <span class="hljs-comment"># set batch size</span> |
| <span class="hljs-attr">onnx_model</span> = update_batch_size(onnx_model, <span class="hljs-number">1</span>) |
| |
| <span class="hljs-comment"># convert onnx graph nodes into SINGA operators</span> |
| <span class="hljs-attr">dev</span> = device.create_cuda_gpu() |
| <span class="hljs-attr">sg_ir</span> = sonnx.prepare(onnx_model, <span class="hljs-attr">device=dev)</span> |
| </code></pre> |
| <h3><a class="anchor" aria-hidden="true" id="inference-singa-model"></a><a href="#inference-singa-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference SINGA model</h3> |
| <p>Once the model is created, you can run inference by calling <code>sg_ir.run</code>. The |
| input and output must be SINGA <code>Tensor</code> instances. Since a SINGA model returns its |
| outputs as a list, if there is only one output, you just need to take the first |
| element of that list.</p> |
| <pre><code class="hljs css language-python3"><span class="hljs-comment"># can wrap the following code in prepare()</span> |
| <span class="hljs-comment"># and provide a flag training=True/False?</span> |
| |
| <span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span> |
| |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, sg_ir)</span></span>: |
| <span class="hljs-keyword">self</span>.sg_ir = sg_ir |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x)</span></span>: |
| <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>] |
| |
| |
| data = get_dataset() |
| x = tensor.Tensor(device=dev, data=data) |
| |
| model = Infer(sg_ir) |
| y = model.forward(x) |
| </code></pre> |
| <h3><a class="anchor" aria-hidden="true" id="saving-singa-model-into-onnx-format"></a><a href="#saving-singa-model-into-onnx-format" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Saving SINGA model into ONNX Format</h3> |
| <p>Given the input tensors and the output tensors generated by the operators of the |
| model, you can trace back all internal operations. Therefore, a SINGA model is |
| defined by its input and output tensors. To export a SINGA model into ONNX |
| format, you just need to provide the input and output tensor lists.</p> |
| <pre><code class="hljs css language-python3"># <span class="hljs-symbol">x</span> is the input tensor, <span class="hljs-symbol">y</span> is the output tensor |
| sonnx.to_onnx([<span class="hljs-symbol">x</span>], [<span class="hljs-symbol">y</span>]) |
| </code></pre> |
| <h3><a class="anchor" aria-hidden="true" id="re-training-an-onnx-model"></a><a href="#re-training-an-onnx-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Re-training an ONNX model</h3> |
| <p>To train (or refine) an ONNX model using SINGA, you need to set the internal |
| tensors to be trainable:</p> |
| <pre><code class="hljs css language-python3"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span> |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span> |
| self.sg_ir = sg_ir |
| <span class="hljs-comment">## can wrap these codes in sonnx?</span> |
| <span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items(): |
| <span class="hljs-comment"># allow the tensors to be updated</span> |
| tens.requires_grad = <span class="hljs-literal">True</span> |
| tens.stores_grad = <span class="hljs-literal">True</span> |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span> |
| <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>] |
| |
| autograd.training = <span class="hljs-literal">False</span> |
| model = Infer(sg_ir) |
| |
| autograd.training = <span class="hljs-literal">True</span> |
| <span class="hljs-comment"># then you train the model as usual</span> |
| <span class="hljs-comment"># give more details??</span> |
| </code></pre> |
| <h3><a class="anchor" aria-hidden="true" id="transfer-learning-an-onnx-model"></a><a href="#transfer-learning-an-onnx-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Transfer-learning an ONNX model</h3> |
| <p>You can also append some layers to the end of an ONNX model to do |
| transfer learning. The <code>last_layers</code> argument cuts the ONNX model to the layers in [0, |
| last_layers]. You can then append more layers using a normal SINGA model.</p> |
| <pre><code class="hljs css language-python3">class Trans: |
| |
| def __init__(<span class="hljs-literal">self</span>, sg_ir, last_layers): |
| <span class="hljs-literal">self</span>.sg_ir = sg_ir |
| <span class="hljs-literal">self</span>.last_layers = last_layers |
| <span class="hljs-literal">self</span>.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=False) |
| <span class="hljs-literal">self</span>.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=False) |
| <span class="hljs-literal">self</span>.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=False) |
| |
| def forward(<span class="hljs-literal">self</span>, <span class="hljs-symbol">x</span>): |
| <span class="hljs-symbol">y</span> = sg_ir.run([<span class="hljs-symbol">x</span>], last_layers=<span class="hljs-literal">self</span>.last_layers)[<span class="hljs-number">0</span>] |
| <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear1(<span class="hljs-symbol">y</span>) |
| <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>) |
| <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear2(<span class="hljs-symbol">y</span>) |
| <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>) |
| <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear3(<span class="hljs-symbol">y</span>) |
| <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>) |
| <span class="hljs-keyword">return</span> <span class="hljs-symbol">y</span> |
| |
| autograd.training = False |
| model = Trans(sg_ir, <span class="hljs-number">-1</span>) |
| |
| # <span class="hljs-keyword">then</span> you train the model as usual |
| </code></pre> |
| <h2><a class="anchor" aria-hidden="true" id="a-full-example"></a><a href="#a-full-example" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>A Full Example</h2> |
| <p>This part introduces the usage of SINGA ONNX through an MNIST example. This |
| section shows how to export, load, run inference on, re-train, and |
| apply transfer learning to the MNIST model. You can try this part in |
| <a href="https://colab.research.google.com/drive/1-YOfQqqw3HNhS8WpB8xjDQYutRdUdmCq">this Colab notebook</a>.</p> |
| <h3><a class="anchor" aria-hidden="true" id="load-dataset"></a><a href="#load-dataset" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Load dataset</h3> |
| <p>Firstly, you need to import some necessary libraries and define some auxiliary |
| functions for downloading and preprocessing the dataset:</p> |
| <pre><code class="hljs css language-python"><span class="hljs-keyword">import</span> os |
| <span class="hljs-keyword">import</span> urllib.request |
| <span class="hljs-keyword">import</span> gzip |
| <span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np |
| <span class="hljs-keyword">import</span> codecs |
| |
| <span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> device |
| <span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> tensor |
| <span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt |
| <span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> autograd |
| <span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> sonnx |
| <span class="hljs-keyword">import</span> onnx |
| |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">load_dataset</span><span class="hljs-params">()</span>:</span> |
| train_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'</span> |
| train_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'</span> |
| valid_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'</span> |
| valid_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'</span> |
| train_x = read_image_file(check_exist_or_download(train_x_url)).astype( |
| np.float32) |
| train_y = read_label_file(check_exist_or_download(train_y_url)).astype( |
| np.float32) |
| valid_x = read_image_file(check_exist_or_download(valid_x_url)).astype( |
| np.float32) |
| valid_y = read_label_file(check_exist_or_download(valid_y_url)).astype( |
| np.float32) |
| <span class="hljs-keyword">return</span> train_x, train_y, valid_x, valid_y |
| |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">check_exist_or_download</span><span class="hljs-params">(url)</span>:</span> |
| |
| download_dir = <span class="hljs-string">'/tmp/'</span> |
| |
| name = url.rsplit(<span class="hljs-string">'/'</span>, <span class="hljs-number">1</span>)[<span class="hljs-number">-1</span>] |
| filename = os.path.join(download_dir, name) |
| <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> os.path.isfile(filename): |
| print(<span class="hljs-string">"Downloading %s"</span> % url) |
| urllib.request.urlretrieve(url, filename) |
| <span class="hljs-keyword">return</span> filename |
| |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_label_file</span><span class="hljs-params">(path)</span>:</span> |
| <span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f: |
| data = f.read() |
| <span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2049</span> |
| length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>]) |
| parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">8</span>).reshape( |
| (length)) |
| <span class="hljs-keyword">return</span> parsed |
| |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">get_int</span><span class="hljs-params">(b)</span>:</span> |
| <span class="hljs-keyword">return</span> int(codecs.encode(b, <span class="hljs-string">'hex'</span>), <span class="hljs-number">16</span>) |
| |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_image_file</span><span class="hljs-params">(path)</span>:</span> |
| <span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f: |
| data = f.read() |
| <span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2051</span> |
| length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>]) |
| num_rows = get_int(data[<span class="hljs-number">8</span>:<span class="hljs-number">12</span>]) |
| num_cols = get_int(data[<span class="hljs-number">12</span>:<span class="hljs-number">16</span>]) |
| parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">16</span>).reshape( |
| (length, <span class="hljs-number">1</span>, num_rows, num_cols)) |
| <span class="hljs-keyword">return</span> parsed |
| |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">to_categorical</span><span class="hljs-params">(y, num_classes)</span>:</span> |
| y = np.array(y, dtype=<span class="hljs-string">"int"</span>) |
| n = y.shape[<span class="hljs-number">0</span>] |
| categorical = np.zeros((n, num_classes)) |
| categorical[np.arange(n), y] = <span class="hljs-number">1</span> |
| categorical = categorical.astype(np.float32) |
| <span class="hljs-keyword">return</span> categorical |
| </code></pre> |
| <h3><a class="anchor" aria-hidden="true" id="mnist-model"></a><a href="#mnist-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>MNIST model</h3> |
| <p>Then you can define a class called <strong>CNN</strong> to construct the MNIST model, which |
| consists of several convolution, pooling, fully connected and ReLU layers. You |
| can also define a function to calculate the <strong>accuracy</strong> of the result. Finally, |
| you can define a <strong>train</strong> and a <strong>test</strong> function to handle the training and |
| prediction processes.</p> |
| <pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">CNN</span>:</span> |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self)</span>:</span> |
| self.conv1 = autograd.Conv2d(<span class="hljs-number">1</span>, <span class="hljs-number">20</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>) |
| self.conv2 = autograd.Conv2d(<span class="hljs-number">20</span>, <span class="hljs-number">50</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>) |
| self.linear1 = autograd.Linear(<span class="hljs-number">4</span> * <span class="hljs-number">4</span> * <span class="hljs-number">50</span>, <span class="hljs-number">500</span>, bias=<span class="hljs-literal">False</span>) |
| self.linear2 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>) |
| self.pooling1 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>) |
| self.pooling2 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>) |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span> |
| y = self.conv1(x) |
| y = autograd.relu(y) |
| y = self.pooling1(y) |
| y = self.conv2(y) |
| y = autograd.relu(y) |
| y = self.pooling2(y) |
| y = autograd.flatten(y) |
| y = self.linear1(y) |
| y = autograd.relu(y) |
| y = self.linear2(y) |
| <span class="hljs-keyword">return</span> y |
| |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">accuracy</span><span class="hljs-params">(pred, target)</span>:</span> |
| y = np.argmax(pred, axis=<span class="hljs-number">1</span>) |
| t = np.argmax(target, axis=<span class="hljs-number">1</span>) |
| a = y == t |
| <span class="hljs-keyword">return</span> np.array(a, <span class="hljs-string">"int"</span>).sum() / float(len(t)) |
| |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train</span><span class="hljs-params">(model, |
| x, |
| y, |
| epochs=<span class="hljs-number">1</span>, |
| batch_size=<span class="hljs-number">64</span>, |
| dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span> |
| batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size |
| |
| <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs): |
| <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number): |
| l_idx = b * batch_size |
| r_idx = (b + <span class="hljs-number">1</span>) * batch_size |
| |
| x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx]) |
| target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx]) |
| |
| output_batch = model.forward(x_batch) |
| <span class="hljs-comment"># onnx_model = sonnx.to_onnx([x_batch], [y])</span> |
| <span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span> |
| |
| loss = autograd.softmax_cross_entropy(output_batch, target_batch) |
| accuracy_rate = accuracy(tensor.to_numpy(output_batch), |
| tensor.to_numpy(target_batch)) |
| |
| sgd = opt.SGD(lr=<span class="hljs-number">0.001</span>) |
| <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss): |
| sgd.update(p, gp) |
| sgd.step() |
| |
| <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>: |
| print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> % |
| (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>])) |
| print(<span class="hljs-string">"training completed"</span>) |
| <span class="hljs-keyword">return</span> x_batch, output_batch |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">test</span><span class="hljs-params">(model, x, y, batch_size=<span class="hljs-number">64</span>, dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span> |
| batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size |
| |
| result = <span class="hljs-number">0</span> |
| <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number): |
| l_idx = b * batch_size |
| r_idx = (b + <span class="hljs-number">1</span>) * batch_size |
| |
| x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx]) |
| target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx]) |
| |
| output_batch = model.forward(x_batch) |
| result += accuracy(tensor.to_numpy(output_batch), |
| tensor.to_numpy(target_batch)) |
| |
| print(<span class="hljs-string">"testing acc %6.2f"</span> % (result / batch_number)) |
| </code></pre> |
| <h3><a class="anchor" aria-hidden="true" id="train-mnist-model-and-export-it-to-onnx"></a><a href="#train-mnist-model-and-export-it-to-onnx" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Train mnist model and export it to onnx</h3> |
| <p>Now, you can train the mnist model and export its onnx model by calling the |
| <strong>sonnx.to_onnx</strong> function.</p> |
| <pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">make_onnx</span><span class="hljs-params">(x, y)</span>:</span> |
| <span class="hljs-keyword">return</span> sonnx.to_onnx([x], [y]) |
| |
| <span class="hljs-comment"># create device</span> |
| dev = device.create_cuda_gpu() |
| <span class="hljs-comment">#dev = device.get_default_device()</span> |
| <span class="hljs-comment"># create model</span> |
| model = CNN() |
| <span class="hljs-comment"># load data</span> |
| train_x, train_y, valid_x, valid_y = load_dataset() |
| <span class="hljs-comment"># normalization</span> |
| train_x = train_x / <span class="hljs-number">255</span> |
| valid_x = valid_x / <span class="hljs-number">255</span> |
| train_y = to_categorical(train_y, <span class="hljs-number">10</span>) |
| valid_y = to_categorical(valid_y, <span class="hljs-number">10</span>) |
| <span class="hljs-comment"># do training</span> |
| autograd.training = <span class="hljs-literal">True</span> |
| x, y = train(model, train_x, train_y, dev=dev) |
| onnx_model = make_onnx(x, y) |
| <span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span> |
| |
| <span class="hljs-comment"># Save the ONNX model</span> |
| model_path = os.path.join(<span class="hljs-string">'/'</span>, <span class="hljs-string">'tmp'</span>, <span class="hljs-string">'mnist.onnx'</span>) |
| onnx.save(onnx_model, model_path) |
| print(<span class="hljs-string">'The model is saved.'</span>) |
| </code></pre> |
| <h3><a class="anchor" aria-hidden="true" id="inference"></a><a href="#inference" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference</h3> |
| <p>After you export the onnx model, you can find a file called <strong>mnist.onnx</strong> in |
| the '/tmp' directory. This model can therefore be imported by other libraries. |
| Now, if you want to import this onnx model into singa again and do the inference |
| using the validation dataset, you can define a class called <strong>Infer</strong>. The |
| forward function of Infer will be called by the test function to do inference |
| on the validation dataset. In addition, you should set the training flag to |
| <strong>False</strong> to fix (freeze) the gradients of the autograd operators.</p> |
| <p>When importing the onnx model, you first need to call <strong>onnx.load</strong> to load the onnx |
| model. Then the onnx model is fed into <strong>sonnx.prepare</strong>, which parses it |
| and initializes a singa model (<strong>sg_ir</strong> in the code). The sg_ir contains a |
| singa graph within it, and you can run a step of inference by feeding |
| input to its run function.</p> |
| <pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span> |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span> |
| self.sg_ir = sg_ir |
| <span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items(): |
| <span class="hljs-comment"># allow the tensors to be updated</span> |
| tens.requires_grad = <span class="hljs-literal">True</span> |
| tens.stores_grad= <span class="hljs-literal">True</span> |
| sg_ir.tensor_map[idx] = tens |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span> |
| <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>] <span class="hljs-comment"># we can run one step of inference by feeding input</span> |
| |
| <span class="hljs-comment"># load the ONNX model</span> |
| onnx_model = onnx.load(model_path) |
| sg_ir = sonnx.prepare(onnx_model, device=dev) <span class="hljs-comment"># parse and initiate to a singa model</span> |
| |
| <span class="hljs-comment"># inference</span> |
| autograd.training = <span class="hljs-literal">False</span> |
| print(<span class="hljs-string">'The inference result is:'</span>) |
| test(Infer(sg_ir), valid_x, valid_y, dev=dev) |
| </code></pre> |
| <h3><a class="anchor" aria-hidden="true" id="re-training"></a><a href="#re-training" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Re-training</h3> |
| <p>Suppose that, after importing the model, you want to re-train it. We can |
| define a function called <strong>re_train</strong>. Before we call this re_train function, we |
| should set the training flag to <strong>True</strong> to make the autograd operators |
| update their gradients. After we finish the training, we set it to <strong>False</strong> |
| again and call the test function to do inference.</p> |
| <pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">re_train</span><span class="hljs-params">(sg_ir, |
| x, |
| y, |
| epochs=<span class="hljs-number">1</span>, |
| batch_size=<span class="hljs-number">64</span>, |
| dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span> |
| batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size |
| |
| new_model = Infer(sg_ir) |
| |
| <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs): |
| <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number): |
| l_idx = b * batch_size |
| r_idx = (b + <span class="hljs-number">1</span>) * batch_size |
| |
| x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx]) |
| target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx]) |
| |
| output_batch = new_model.forward(x_batch) |
| |
| loss = autograd.softmax_cross_entropy(output_batch, target_batch) |
| accuracy_rate = accuracy(tensor.to_numpy(output_batch), |
| tensor.to_numpy(target_batch)) |
| |
| sgd = opt.SGD(lr=<span class="hljs-number">0.01</span>) |
| <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss): |
| sgd.update(p, gp) |
| sgd.step() |
| |
| <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>: |
| print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> % |
| (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>])) |
| print(<span class="hljs-string">"re-training completed"</span>) |
| <span class="hljs-keyword">return</span> new_model |
| |
| <span class="hljs-comment"># load the ONNX model</span> |
| onnx_model = onnx.load(model_path) |
| sg_ir = sonnx.prepare(onnx_model, device=dev) |
| |
| <span class="hljs-comment"># re-training</span> |
| autograd.training = <span class="hljs-literal">True</span> |
| new_model = re_train(sg_ir, train_x, train_y, dev=dev) |
| autograd.training = <span class="hljs-literal">False</span> |
| test(new_model, valid_x, valid_y, dev=dev) |
| </code></pre> |
| <h3><a class="anchor" aria-hidden="true" id="transfer-learning"></a><a href="#transfer-learning" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Transfer learning</h3> |
| <p>Finally, if we want to do transfer-learning, we can define a class called |
| <strong>Trans</strong> to append some layers after the onnx model. For demonstration, the |
| code only appends several linear (fully connected) and relu layers after the onnx |
| model. You can define a transfer_learning function to handle the training |
| process of the transfer-learning model. The training flag is set in the same way |
| as in the previous section.</p> |
| <pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Trans</span>:</span> |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir, last_layers)</span>:</span> |
| self.sg_ir = sg_ir |
| self.last_layers = last_layers |
| self.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=<span class="hljs-literal">False</span>) |
| self.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=<span class="hljs-literal">False</span>) |
| self.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>) |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span> |
| y = sg_ir.run([x], last_layers=self.last_layers)[<span class="hljs-number">0</span>] |
| y = self.append_linear1(y) |
| y = autograd.relu(y) |
| y = self.append_linear2(y) |
| y = autograd.relu(y) |
| y = self.append_linear3(y) |
| y = autograd.relu(y) |
| <span class="hljs-keyword">return</span> y |
| |
| <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">transfer_learning</span><span class="hljs-params">(sg_ir, |
| x, |
| y, |
| epochs=<span class="hljs-number">1</span>, |
| batch_size=<span class="hljs-number">64</span>, |
| dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span> |
| batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size |
| |
| trans_model = Trans(sg_ir, <span class="hljs-number">-1</span>) |
| |
| <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs): |
| <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number): |
| l_idx = b * batch_size |
| r_idx = (b + <span class="hljs-number">1</span>) * batch_size |
| |
| x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx]) |
| target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx]) |
| output_batch = trans_model.forward(x_batch) |
| |
| loss = autograd.softmax_cross_entropy(output_batch, target_batch) |
| accuracy_rate = accuracy(tensor.to_numpy(output_batch), |
| tensor.to_numpy(target_batch)) |
| |
| sgd = opt.SGD(lr=<span class="hljs-number">0.07</span>) |
| <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss): |
| sgd.update(p, gp) |
| sgd.step() |
| |
| <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>: |
| print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> % |
| (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>])) |
| print(<span class="hljs-string">"transfer-learning completed"</span>) |
| <span class="hljs-keyword">return</span> trans_model |
| |
| <span class="hljs-comment"># load the ONNX model</span> |
| onnx_model = onnx.load(model_path) |
| sg_ir = sonnx.prepare(onnx_model, device=dev) |
| |
| <span class="hljs-comment"># transfer-learning</span> |
| autograd.training = <span class="hljs-literal">True</span> |
| new_model = transfer_learning(sg_ir, train_x, train_y, dev=dev) |
| autograd.training = <span class="hljs-literal">False</span> |
| test(new_model, valid_x, valid_y, dev=dev) |
| </code></pre> |
| <h2><a class="anchor" aria-hidden="true" id="onnx-model-zoo"></a><a href="#onnx-model-zoo" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>ONNX model zoo</h2> |
| <p>The <a href="https://github.com/onnx/models">ONNX Model Zoo</a> is a collection of |
| pre-trained, state-of-the-art models in the ONNX format contributed by community |
| members. SINGA has supported several CV and NLP models now. More models are |
| going to be supported soon.</p> |
| <h3><a class="anchor" aria-hidden="true" id="image-classification"></a><a href="#image-classification" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Image Classification</h3> |
| <p>This collection of models takes images as input, then classifies the major |
| objects in the images into 1000 object categories such as keyboard, mouse, |
| pencil, and many animals.</p> |
| <table> |
| <thead> |
| <tr><th>Model Class</th><th>Reference</th><th>Description</th><th>Link</th></tr> |
| </thead> |
| <tbody> |
| <tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/mobilenet">MobileNet</a></b></td><td><a href="https://arxiv.org/abs/1801.04381">Sandler et al.</a></td><td>Light-weight deep neural network best suited for mobile and embedded vision applications. <br>Top-5 error from paper - ~10%</td><td><a href="https://colab.research.google.com/drive/1HsixqJMIpKyEPhkbB8jy7NwNEFEAUWAf"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr> |
| <tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/resnet">ResNet18</a></b></td><td><a href="https://arxiv.org/abs/1512.03385">He et al.</a></td><td>A CNN model (up to 152 layers). Uses shortcut connections to achieve higher accuracy when classifying images. <br> Top-5 error from paper - ~3.6%</td><td><a href="https://colab.research.google.com/drive/1u1RYefSsVbiP4I-5wiBKHjsT9L0FxLm9"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr> |
| <tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/vgg">VGG16</a></b></td><td><a href="https://arxiv.org/abs/1409.1556">Simonyan et al.</a></td><td>Deep CNN model(up to 19 layers). Similar to AlexNet but uses multiple smaller kernel-sized filters that provides more accuracy when classifying images. <br>Top-5 error from paper - ~8%</td><td><a href="https://colab.research.google.com/drive/14kxgRKtbjPCKKsDJVNi3AvTev81Gp_Ds"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr> |
| </tbody> |
| </table> |
| <h3><a class="anchor" aria-hidden="true" id="object-detection"></a><a href="#object-detection" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Object Detection</h3> |
| <p>Object detection models detect the presence of multiple objects in an image and |
| segment out areas of the image where the objects are detected.</p> |
| <table> |
| <thead> |
| <tr><th>Model Class</th><th>Reference</th><th>Description</th><th>Link</th></tr> |
| </thead> |
| <tbody> |
| <tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/tiny_yolov2">Tiny YOLOv2</a></b></td><td><a href="https://arxiv.org/pdf/1612.08242.pdf">Redmon et al.</a></td><td>A real-time CNN for object detection that detects 20 different classes. A smaller version of the more complex full YOLOv2 network.</td><td><a href="https://colab.research.google.com/drive/11V4I6cRjIJNUv5ZGsEGwqHuoQEie6b1T"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr> |
| </tbody> |
| </table> |
| <h3><a class="anchor" aria-hidden="true" id="face-analysis"></a><a href="#face-analysis" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Face Analysis</h3> |
| <p>Face detection models identify and/or recognize human faces and emotions in |
| given images.</p> |
| <table> |
| <thead> |
| <tr><th>Model Class</th><th>Reference</th><th>Description</th><th>Link</th></tr> |
| </thead> |
| <tbody> |
| <tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/body_analysis/arcface">ArcFace</a></b></td><td><a href="https://arxiv.org/abs/1801.07698">Deng et al.</a></td><td>A CNN based model for face recognition which learns discriminative features of faces and produces embeddings for input face images.</td><td><a href="https://colab.research.google.com/drive/1qanaqUKGIDtifdzEzJOHjEj4kYzA9uJC"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr> |
| <tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/body_analysis/emotion_ferplus">Emotion FerPlus</a></b></td><td><a href="https://arxiv.org/abs/1608.01041">Barsoum et al.</a></td><td>Deep CNN for emotion recognition trained on images of faces.</td><td><a href="https://colab.research.google.com/drive/1XHtBQGRhe58PDi4LGYJzYueWBeWbO23r"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr> |
| </tbody> |
| </table> |
| <h3><a class="anchor" aria-hidden="true" id="machine-comprehension"></a><a href="#machine-comprehension" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Machine Comprehension</h3> |
| <p>This subset of natural language processing models answers questions about a |
| given context paragraph.</p> |
| <table> |
| <thead> |
| <tr><th>Model Class</th><th>Reference</th><th>Description</th><th>Link</th></tr> |
| </thead> |
| <tbody> |
| <tr><td><b><a href="https://github.com/onnx/models/tree/master/text/machine_comprehension/bert-squad">BERT-Squad</a></b></td><td><a href="https://arxiv.org/pdf/1810.04805.pdf">Devlin et al.</a></td><td>This model answers questions based on the context of the given input paragraph.</td><td><a href="https://colab.research.google.com/drive/1kud-lUPjS_u-TkDAzihBTw0Vqr0FjCE-"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr> |
| </tbody> |
| </table> |
| <h2><a class="anchor" aria-hidden="true" id="supported-operators"></a><a href="#supported-operators" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Supported operators</h2> |
| <p>The following operators are supported:</p> |
| <ul> |
| <li>Conv</li> |
| <li>Relu</li> |
| <li>Constant</li> |
| <li>MaxPool</li> |
| <li>AveragePool</li> |
| <li>Softmax</li> |
| <li>Sigmoid</li> |
| <li>Add</li> |
| <li>MatMul</li> |
| <li>BatchNormalization</li> |
| <li>Concat</li> |
| <li>Flatten</li> |
| <li>Add</li> |
| <li>Gemm</li> |
| <li>Reshape</li> |
| <li>Sum</li> |
| <li>Cos</li> |
| <li>Cosh</li> |
| <li>Sin</li> |
| <li>Sinh</li> |
| <li>Tan</li> |
| <li>Tanh</li> |
| <li>Acos</li> |
| <li>Acosh</li> |
| <li>Asin</li> |
| <li>Asinh</li> |
| <li>Atan</li> |
| <li>Atanh</li> |
| <li>Selu</li> |
| <li>Elu</li> |
| <li>Equal</li> |
| <li>Less</li> |
| <li>Sign</li> |
| <li>Div</li> |
| <li>Sub</li> |
| <li>Sqrt</li> |
| <li>Log</li> |
| <li>Greater</li> |
| <li>HardSigmoid</li> |
| <li>Identity</li> |
| <li>Softplus</li> |
| <li>Softsign</li> |
| <li>Mean</li> |
| <li>Pow</li> |
| <li>Clip</li> |
| <li>PRelu</li> |
| <li>Mul</li> |
| <li>Transpose</li> |
| <li>Max</li> |
| <li>Min</li> |
| <li>Shape</li> |
| <li>And</li> |
| <li>Or</li> |
| <li>Xor</li> |
| <li>Not</li> |
| <li>Neg</li> |
| <li>Reciprocal</li> |
| <li>LeakyRelu</li> |
| <li>GlobalAveragePool</li> |
| <li>ConstantOfShape</li> |
| <li>Dropout</li> |
| <li>ReduceSum</li> |
| <li>ReduceMean</li> |
| <li>LeakyRelu</li> |
| <li>GlobalAveragePool</li> |
| <li>Squeeze</li> |
| <li>Unsqueeze</li> |
| <li>Slice</li> |
| <li>Ceil</li> |
| <li>Split</li> |
| <li>Gather</li> |
| <li>Tile</li> |
| <li>NonZero</li> |
| <li>Cast</li> |
| <li>OneHot</li> |
| </ul> |
| <h3><a class="anchor" aria-hidden="true" id="special-comments-for-onnx-backend"></a><a href="#special-comments-for-onnx-backend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Special comments for ONNX backend</h3> |
| <ul> |
| <li><p>Conv, MaxPool and AveragePool</p> |
| <p>Input must be 1d (<code>N*C*H</code>) or 2d (<code>N*C*H*W</code>) shape and <code>dilation</code> must be 1.</p></li> |
| <li><p>BatchNormalization</p> |
| <p><code>epsilon</code> is 1e-05 and cannot be changed.</p></li> |
| <li><p>Cast</p> |
| <p>Only float32 and int32 are supported; other types are cast to these two types.</p></li> |
| <li><p>Squeeze and Unsqueeze</p> |
| <p>If you encounter errors when you <code>Squeeze</code> or <code>Unsqueeze</code> between <code>Tensor</code> and |
| Scalar, please report to us.</p></li> |
| <li><p>Empty tensor</p> |
| <p>Empty tensor is illegal in SINGA.</p></li> |
| </ul> |
| <h2><a class="anchor" aria-hidden="true" id="implementation"></a><a href="#implementation" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Implementation</h2> |
| <p>The code of SINGA ONNX is located at <code>python/singa/sonnx.py</code>. There are three main |
| classes: <code>SingaFrontend</code>, <code>SingaBackend</code> and <code>SingaRep</code>. <code>SingaFrontend</code> |
| translates a SINGA model to an ONNX model; <code>SingaBackend</code> translates an ONNX model |
| to a <code>SingaRep</code> object which stores all SINGA operators and tensors (the tensor in |
| this doc means SINGA <code>Tensor</code>); <code>SingaRep</code> can be run like a SINGA model.</p> |
| <h3><a class="anchor" aria-hidden="true" id="singafrontend"></a><a href="#singafrontend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaFrontend</h3> |
| <p>The entry function of <code>SingaFrontend</code> is <code>singa_to_onnx_model</code>, which is also |
| called <code>to_onnx</code>. <code>singa_to_onnx_model</code> creates the ONNX model, and it also |
| creates an ONNX graph by using <code>singa_to_onnx_graph</code>.</p> |
| <p><code>singa_to_onnx_graph</code> accepts the output of the model, and recursively iterates |
| the SINGA model's graph from the output to get all operators to form a queue. |
| The input and intermediate tensors, i.e., trainable weights, of the SINGA model |
| are picked up at the same time. The input is stored in <code>onnx_model.graph.input</code>; |
| the output is stored in <code>onnx_model.graph.output</code>; and the trainable weights are |
| stored in <code>onnx_model.graph.initializer</code>.</p> |
| <p>Then the SINGA operators in the queue are translated to ONNX operators one by one. |
| <code>_rename_operators</code> defines the operator name mapping between SINGA and ONNX. |
| <code>_special_operators</code> defines which function is used to translate each |
| operator.</p> |
| <p>In addition, some operators in SINGA have different definitions from ONNX, that |
| is, ONNX regards some attributes of SINGA operators as input, so |
| <code>_unhandled_operators</code> defines which function handles each such special operator.</p> |
| <p>Since the bool type is regarded as int32 in SINGA, <code>_bool_operators</code> defines the |
| operators to be changed as bool type.</p> |
| <h3><a class="anchor" aria-hidden="true" id="singabackend"></a><a href="#singabackend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaBackend</h3> |
| <p>The entry function of <code>SingaBackend</code> is <code>prepare</code>, which checks the version of |
| the ONNX model and then calls <code>_onnx_model_to_singa_net</code>.</p> |
| <p>The purpose of <code>_onnx_model_to_singa_net</code> is to get SINGA tensors and operators. |
| The tensors are stored in a dictionary by their name in ONNX, and operators are |
| stored in a queue in the form of |
| <code>namedtuple('SingaOps', ['name', 'op', 'handle', 'forward'])</code>. For each |
| operator, <code>name</code> is its ONNX node name; <code>op</code> is the ONNX node; <code>forward</code> is the |
| SINGA operator's forward function; <code>handle</code> is prepared for some special |
| operators such as Conv and Pooling, which have a <code>handle</code> object.</p> |
| <p>The first step of <code>_onnx_model_to_singa_net</code> is to call <code>_init_graph_parameter</code> |
| to get all tensors within the model. For trainable weights, it can initialize a SINGA |
| <code>Tensor</code> from <code>onnx_model.graph.initializer</code>. Please note, the weights may also |
| be stored within the graph's input or an ONNX node called <code>Constant</code>; SINGA can also |
| handle these.</p> |
| <p>Though all weights are stored within the ONNX model, the input of the model is |
| unknown except for its shape and type. So SINGA supports two ways to initialize the input: (1) |
| generate a random tensor from its shape and type, or (2) allow the user to assign the |
| input. The first way works fine for most models; however, for some models such as |
| BERT, the matrix indices cannot be randomly generated, otherwise errors will |
| occur.</p> |
| <p>Then, <code>_onnx_model_to_singa_net</code> iterates over all nodes within the ONNX graph to |
| translate them to SINGA operators. Also, <code>_rename_operators</code> defines the operator |
| name mapping between SINGA and ONNX. <code>_special_operators</code> defines which function |
| is used to translate each operator. <code>_run_node</code> runs the generated SINGA model |
| with its input tensors and stores its output tensors for use by later |
| operators.</p> |
| <p>This class finally returns a <code>SingaRep</code> object and stores all SINGA tensors and |
| operators within it.</p> |
| <h3><a class="anchor" aria-hidden="true" id="singarep"></a><a href="#singarep" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaRep</h3> |
| <p><code>SingaRep</code> stores all SINGA tensors and operators. <code>run</code> accepts the input |
| of the model and runs the SINGA operators one by one following the operator |
| queue. The user can use <code>last_layers</code> to decide to run the model till the last |
| few layers. Set <code>all_outputs</code> to <code>False</code> to get only the final output, or <code>True</code> to |
| also get all the intermediate outputs.</p> |
| </span></div></article></div><div class="docLastUpdate"><em>Last updated on 10/04/2020</em></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/3.0.0/graph"><span class="arrow-prev">← </span><span>Model</span></a><a class="docs-next button" href="/docs/3.0.0/dist-train"><span>Distributed Training</span><span class="arrow-next"> →</span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#general-usage">General usage</a><ul class="toc-headings"><li><a href="#loading-an-onnx-model-into-singa">Loading an ONNX Model into SINGA</a></li><li><a href="#inference-singa-model">Inference SINGA model</a></li><li><a href="#saving-singa-model-into-onnx-format">Saving SINGA model into ONNX Format</a></li><li><a href="#re-training-an-onnx-model">Re-training an ONNX model</a></li><li><a href="#transfer-learning-an-onnx-model">Transfer-learning an ONNX model</a></li></ul></li><li><a href="#a-full-example">A Full Example</a><ul class="toc-headings"><li><a href="#load-dataset">Load dataset</a></li><li><a href="#mnist-model">MNIST model</a></li><li><a href="#train-mnist-model-and-export-it-to-onnx">Train mnist model and export it to onnx</a></li><li><a href="#inference">Inference</a></li><li><a href="#re-training">Re-training</a></li><li><a href="#transfer-learning">Transfer learning</a></li></ul></li><li><a href="#onnx-model-zoo">ONNX model zoo</a><ul class="toc-headings"><li><a href="#image-classification">Image Classification</a></li><li><a href="#object-detection">Object Detection</a></li><li><a href="#face-analysis">Face Analysis</a></li><li><a href="#machine-comprehension">Machine Comprehension</a></li></ul></li><li><a href="#supported-operators">Supported operators</a><ul class="toc-headings"><li><a href="#special-comments-for-onnx-backend">Special comments for ONNX backend</a></li></ul></li><li><a href="#implementation">Implementation</a><ul class="toc-headings"><li><a href="#singafrontend">SingaFrontend</a></li><li><a 
href="#singabackend">SingaBackend</a></li><li><a href="#singarep">SingaRep</a></li></ul></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a href="/en/https://apache-singa.readthedocs.io/en/latest/">API Reference</a><a href="/docs/examples">Examples</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/blog">SINGA News</a><a href="https://github.com/apache/singa">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img 
src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2020 |
| The Apache Software Foundation. All rights reserved. |
| Apache SINGA, Apache, the Apache feather logo, and |
| the Apache SINGA project logos are trademarks of The |
| Apache Software Foundation. All other marks mentioned |
| may be trademarks or registered trademarks of their |
| respective owners.</section></div></footer></div><script type="text/javascript" src="https://cdn.jsdelivr.net/docsearch.js/1/docsearch.min.js"></script><script>window.twttr=(function(d,s, id){var js,fjs=d.getElementsByTagName(s)[0],t=window.twttr||{};if(d.getElementById(id))return t;js=d.createElement(s);js.id=id;js.src='https://platform.twitter.com/widgets.js';fjs.parentNode.insertBefore(js, fjs);t._e = [];t.ready = function(f) {t._e.push(f);};return t;}(document, 'script', 'twitter-wjs'));</script><script> |
| document.addEventListener('keyup', function(e) { |
| if (e.target !== document.body) { |
| return; |
| } |
| // keyCode for '/' (slash) |
| if (e.keyCode === 191) { |
| const search = document.getElementById('search_input_react'); |
| search && search.focus(); |
| } |
| }); |
| </script><script> |
| var search = docsearch({ |
| |
| apiKey: '45202133606c0b5fa6d21cddc4725dd8', |
| indexName: 'apache_singa', |
| inputSelector: '#search_input_react', |
| algoliaOptions: {"facetFilters":["language:en","version:3.0.0"]} |
| }); |
| </script></body></html> |