<!DOCTYPE html><html lang="en"><head><meta charSet="utf-8"/><meta http-equiv="X-UA-Compatible" content="IE=edge"/><title>ONNX · Apache SINGA</title><meta name="viewport" content="width=device-width"/><meta name="generator" content="Docusaurus"/><meta name="description" content="&lt;!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements.  See the NOTICE file distributed with this work for additional information regarding copyright ownership.  The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License.  You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the License for the specific language governing permissions and limitations under the License.  --&gt;"/><meta name="docsearch:version" content="3.1.0_Chinese"/><meta name="docsearch:language" content="en"/><meta property="og:title" content="ONNX · Apache SINGA"/><meta property="og:type" content="website"/><meta property="og:url" content="https://singa.apache.org/"/><meta property="og:description" content="&lt;!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements.  See the NOTICE file distributed with this work for additional information regarding copyright ownership.  The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License.  
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the License for the specific language governing permissions and limitations under the License.  --&gt;"/><meta property="og:image" content="https://singa.apache.org/img/singa_twitter_banner.jpeg"/><meta name="twitter:card" content="summary"/><meta name="twitter:image" content="https://singa.apache.org/img/singa_twitter_banner.jpeg"/><link rel="shortcut icon" href="/img/favicon.ico"/><link rel="stylesheet" href="https://cdn.jsdelivr.net/docsearch.js/1/docsearch.min.css"/><link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/atom-one-dark.min.css"/><link rel="alternate" type="application/atom+xml" href="https://singa.apache.org/blog/atom.xml" title="Apache SINGA Blog ATOM Feed"/><link rel="alternate" type="application/rss+xml" href="https://singa.apache.org/blog/feed.xml" title="Apache SINGA Blog RSS Feed"/><script type="text/javascript" src="https://buttons.github.io/buttons.js"></script><script src="https://unpkg.com/vanilla-back-to-top@7.1.14/dist/vanilla-back-to-top.min.js"></script><script>
        // After the DOM has been parsed, install the back-to-top button
        // (addBackToTop comes from the vanilla-back-to-top script loaded above).
        document.addEventListener('DOMContentLoaded', function() {
          var backToTopOptions = {"zIndex":100};
          addBackToTop(backToTopOptions);
        });
        </script><script src="/js/scrollSpy.js"></script><link rel="stylesheet" href="/css/main.css"/><script src="/js/codetabs.js"></script></head><body class="sideNavVisible separateOnPageNav"><div class="fixedHeaderContainer"><div class="headerWrapper wrapper"><header><a href="/"><img class="logo" src="/img/singa.png" alt="Apache SINGA"/></a><a href="/versions"><h3>3.1.0_Chinese</h3></a><div class="navigationWrapper navigationSlider"><nav class="slidingNav"><ul class="nav-site nav-site-internal"><li class="siteNavGroupActive"><a href="/docs/3.1.0_Chinese/installation" target="_self">Docs</a></li><li class=""><a href="/docs/3.1.0_Chinese/source-repository" target="_self">Community</a></li><li class=""><a href="/blog/" target="_self">News</a></li><li class=""><a href="https://apache-singa.readthedocs.io/en/latest/" target="_self">API</a></li><li class="navSearchWrapper reactNavSearchWrapper"><input type="text" id="search_input_react" placeholder="Search" title="Search"/></li><li class=""><a href="https://github.com/apache/singa" target="_self">GitHub</a></li></ul></nav></div></header></div></div><div class="navPusher"><div class="docMainWrapper wrapper"><div class="docsNavContainer" id="docsNav"><nav class="toc"><div class="toggleNav"><section class="navWrapper wrapper"><div class="navBreadcrumb wrapper"><div class="navToggle" id="navToggler"><div class="hamburger-menu"><div class="line1"></div><div class="line2"></div><div class="line3"></div></div></div><h2><i>›</i><span>Guides</span></h2><div class="tocToggler" id="tocToggler"><i class="icon-toc"></i></div></div><div class="navGroups"><div class="navGroup"><h3 class="navGroupCategoryTitle">Getting Started</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/installation">Installation</a></li><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/software-stack">Software Stack</a></li><li class="navListItem"><a class="navItem" 
href="/docs/3.1.0_Chinese/examples">Examples</a></li></ul></div><div class="navGroup"><h3 class="navGroupCategoryTitle">Guides</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/device">Device</a></li><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/tensor">Tensor</a></li><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/autograd">Autograd</a></li><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/optimizer">Optimizer</a></li><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/graph">Model</a></li><li class="navListItem navListItemActive"><a class="navItem" href="/docs/3.1.0_Chinese/onnx">ONNX</a></li><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/dist-train">Distributed Training</a></li><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/time-profiling">Time Profiling</a></li></ul></div><div class="navGroup"><h3 class="navGroupCategoryTitle">Development</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/downloads">Download SINGA</a></li><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/build">Build SINGA from Source</a></li><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/contribute-code">How to Contribute Code</a></li><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/contribute-docs">How to Contribute to Documentation</a></li><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/how-to-release">How to Prepare a Release</a></li><li class="navListItem"><a class="navItem" href="/docs/3.1.0_Chinese/git-workflow">Git Workflow</a></li></ul></div></div></section></div><script>
            // Sidebar category setup. This runs inline (not on DOMContentLoaded),
            // directly after the sidebar markup that precedes this script.
            // Every element with class `collapsible` is treated as a category
            // header whose next element sibling holds that category's items.
            var coll = document.getElementsByClassName('collapsible');
            // True until the first category containing the active page is found,
            // so at most one category gets the initial toggle below.
            var checkActiveCategory = true;
            for (var i = 0; i < coll.length; i++) {
              // All descendants of the category's content element.
              var links = coll[i].nextElementSibling.getElementsByTagName('*');
              if (checkActiveCategory){
                for (var j = 0; j < links.length; j++) {
                  if (links[j].classList.contains('navListItemActive')){
                    // Flip the `hide` class on the content and `rotate` on the
                    // header's second child node (called `arrow` in the click
                    // handler below). NOTE(review): whether this reveals or hides
                    // the list depends on the initial markup/CSS - confirm.
                    coll[i].nextElementSibling.classList.toggle('hide');
                    coll[i].childNodes[1].classList.toggle('rotate');
                    checkActiveCategory = false;
                    break;
                  }
                }
              }

              // Clicking a category header flips the same two classes;
              // `this` is the clicked header element.
              coll[i].addEventListener('click', function() {
                var arrow = this.childNodes[1];
                arrow.classList.toggle('rotate');
                var content = this.nextElementSibling;
                content.classList.toggle('hide');
              });
            }

            // Once the DOM is ready, wire the mobile navigation and
            // table-of-contents toggles.
            document.addEventListener('DOMContentLoaded', function() {
              // Makes a click on `togglerSelector` flip `className` on the
              // element matched by `targetSelector`. Does nothing when the
              // toggler element is not present on the page.
              function createToggler(togglerSelector, targetSelector, className) {
                var toggler = document.querySelector(togglerSelector);
                var target = document.querySelector(targetSelector);

                if (toggler) {
                  toggler.onclick = function(event) {
                    event.preventDefault();
                    target.classList.toggle(className);
                  };
                }
              }

              createToggler('#navToggler', '#docsNav', 'docsSliderActive');
              createToggler('#tocToggler', 'body', 'tocActive');

              // A click on any link inside the TOC headings removes the
              // `tocActive` class from <body>.
              var tocHeadings = document.querySelector('.toc-headings');
              if (tocHeadings) {
                tocHeadings.addEventListener('click', function(event) {
                  // Walk up from the click target until an <a> is found or the
                  // container itself is reached.
                  for (var node = event.target; node !== tocHeadings; node = node.parentNode) {
                    if (node.tagName === 'A') {
                      document.body.classList.remove('tocActive');
                      break;
                    }
                  }
                }, false);
              }
            });
        </script></nav></div><div class="container mainContainer docsContainer"><div class="wrapper"><div class="post"><header class="postHeader"><a class="edit-page-link button" href="https://github.com/apache/singa-doc/blob/master/docs-site/docs/onnx.md" target="_blank" rel="noreferrer noopener">Edit</a><h1 id="__docusaurus" class="postHeaderTitle">ONNX</h1></header><article><div><span><!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements.  See the NOTICE file distributed with this work for additional information regarding copyright ownership.  The ASF licenses this file to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.  You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the License for the specific language governing permissions and limitations under the License.  -->
<p><a href="https://onnx.ai/">ONNX</a> 是机器学习模型的开放表示格式，它使AI开发人员能够在不同的库和工具中使用模型。SINGA支持加载ONNX格式模型用于训练和inference，并将使用SINGA API（如<a href="./module">Module</a>）定义的模型保存为ONNX格式。</p>
<p>SINGA已在以下<a href="https://github.com/onnx/onnx/blob/master/docs/Versioning.md">版本</a>的ONNX上完成测试。</p>
<table>
<thead>
<tr><th>ONNX version</th><th>File format version</th><th>Opset version ai.onnx</th><th>Opset version ai.onnx.ml</th><th>Opset version ai.onnx.training</th></tr>
</thead>
<tbody>
<tr><td>1.6.0</td><td>6</td><td>11</td><td>2</td><td>-</td></tr>
</tbody>
</table>
<h2><a class="anchor" aria-hidden="true" id="通常用法"></a><a href="#通常用法" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>通常用法</h2>
<h3><a class="anchor" aria-hidden="true" id="从onnx中读取一个model到singa"></a><a href="#从onnx中读取一个model到singa" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>从ONNX中读取一个Model到SINGA</h3>
<p>在通过 <code>onnx.load</code> 从磁盘加载 ONNX 模型后，您需要更新模型的batch_size，因为对于大多数模型来说，它们使用一个占位符来表示其批处理量。我们在这里举一个例子，若要 <code>update_batch_size</code>，你只需要更新输入和输出的 batch_size，内部的 tensors 的形状会自动推断出来。</p>
<p>然后，您可以使用 <code>sonnx.prepare</code> 来准备 SINGA 模型。该函数将 ONNX 模型图中的所有节点迭代并翻译成 SINGA 运算符，加载所有存储的权重并推断每个中间张量的形状。</p>
<pre><code class="hljs css language-python3"><span class="hljs-built_in">import</span> onnx
from singa <span class="hljs-built_in">import</span> device
from singa <span class="hljs-built_in">import</span> sonnx

<span class="hljs-comment"># if the input has multiple tensors? can put this function inside prepare()?</span>
def update_batch_size(onnx_model, batch_size):
    <span class="hljs-attr">model_input</span> = onnx_model.graph.input[<span class="hljs-number">0</span>]
    model_input.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size
    <span class="hljs-attr">model_output</span> = onnx_model.graph.output[<span class="hljs-number">0</span>]
    model_output.type.tensor_type.shape.dim[<span class="hljs-number">0</span>].<span class="hljs-attr">dim_value</span> = batch_size
    return onnx_model


<span class="hljs-attr">model_path</span> = <span class="hljs-string">"PATH/To/ONNX/MODEL"</span>
<span class="hljs-attr">onnx_model</span> = onnx.load(model_path)

<span class="hljs-comment"># set batch size</span>
<span class="hljs-attr">onnx_model</span> = update_batch_size(onnx_model, <span class="hljs-number">1</span>)

<span class="hljs-comment"># convert onnx graph nodes into SINGA operators</span>
<span class="hljs-attr">dev</span> = device.create_cuda_gpu()
<span class="hljs-attr">sg_ir</span> = sonnx.prepare(onnx_model, <span class="hljs-attr">device=dev)</span>
</code></pre>
<h3><a class="anchor" aria-hidden="true" id="inference-singa模型"></a><a href="#inference-singa模型" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference SINGA模型</h3>
<p>一旦创建了模型，就可以通过调用<code>sg_ir.run</code>进行inference。输入和输出必须是SINGA Tensor实例，由于SINGA模型以列表形式返回输出，如果只有一个输出，你只需要从输出中取第一个元素即可。</p>
<pre><code class="hljs css language-python3"><span class="hljs-comment"># can wrap the following code in prepare()</span>
<span class="hljs-comment"># and provide a flag training=True/False?</span>

<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
    <span class="hljs-comment"># Thin wrapper that runs inference on a model prepared by sonnx.prepare().</span>

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, sg_ir)</span></span>:
        <span class="hljs-keyword">self</span>.sg_ir = sg_ir

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(<span class="hljs-keyword">self</span>, x)</span></span>:
        <span class="hljs-comment"># use the model stored on this instance; run() returns a list of outputs</span>
        <span class="hljs-keyword">return</span> <span class="hljs-keyword">self</span>.sg_ir.run([x])[<span class="hljs-number">0</span>]


data = get_dataset()
x = tensor.Tensor(device=dev, data=data)

model = Infer(sg_ir)
y = model.forward(x)
</code></pre>
<h3><a class="anchor" aria-hidden="true" id="将singa模型保存成onnx格式"></a><a href="#将singa模型保存成onnx格式" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>将SINGA模型保存成ONNX格式</h3>
<p>给定输入张量和输出张量（它们由模型中的运算符生成），您可以回溯出所有内部操作。因此，一个SINGA模型是由输入和输出张量定义的，要将 SINGA 模型导出为 ONNX 格式，您只需提供输入和输出张量列表。</p>
<pre><code class="hljs css language-python3"># <span class="hljs-symbol">x</span> is the input tensor, <span class="hljs-symbol">y</span> is the output tensor
sonnx.to_onnx([<span class="hljs-symbol">x</span>], [<span class="hljs-symbol">y</span>])
</code></pre>
<h3><a class="anchor" aria-hidden="true" id="在onnx模型上重新训练"></a><a href="#在onnx模型上重新训练" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>在ONNX模型上重新训练</h3>
<p>要使用 SINGA 训练（或改进）ONNX 模型，您需要设置内部的张量为可训练状态：</p>
<pre><code class="hljs css language-python3"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
    <span class="hljs-comment"># Wraps a prepared model and marks every tensor in its tensor_map as trainable.</span>

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span>
        self.sg_ir = sg_ir
        <span class="hljs-comment">## can wrap these codes in sonnx?</span>
        <span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items():
            <span class="hljs-comment"># allow the tensors to be updated</span>
            tens.requires_grad = <span class="hljs-literal">True</span>
            tens.stores_grad = <span class="hljs-literal">True</span>

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
        <span class="hljs-comment"># use the model stored on this instance; run() returns a list of outputs</span>
        <span class="hljs-keyword">return</span> self.sg_ir.run([x])[<span class="hljs-number">0</span>]

autograd.training = <span class="hljs-literal">False</span>
model = Infer(sg_ir)

autograd.training = <span class="hljs-literal">True</span>
<span class="hljs-comment"># then you train the model as usual</span>
<span class="hljs-comment"># give more details??</span>
</code></pre>
<h3><a class="anchor" aria-hidden="true" id="在onnx模型上做迁移学习"></a><a href="#在onnx模型上做迁移学习" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>在ONNX模型上做迁移学习</h3>
<p>您也可以在ONNX模型的最后附加一些层来进行迁移学习。<code>last_layers</code> 意味着您从 [0, last_layers] 切断 ONNX 层。然后您可以通过普通的SINGA模型附加更多的层。</p>
<pre><code class="hljs css language-python3">class Trans:
    # Runs the prepared ONNX model up to last_layers, then three extra Linear + relu layers.

    def __init__(<span class="hljs-literal">self</span>, sg_ir, last_layers):
        <span class="hljs-literal">self</span>.sg_ir = sg_ir
        <span class="hljs-literal">self</span>.last_layers = last_layers
        <span class="hljs-literal">self</span>.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=False)
        <span class="hljs-literal">self</span>.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=False)
        <span class="hljs-literal">self</span>.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=False)

    def forward(<span class="hljs-literal">self</span>, <span class="hljs-symbol">x</span>):
        # use the model stored on this instance, not a global sg_ir
        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.sg_ir.run([<span class="hljs-symbol">x</span>], last_layers=<span class="hljs-literal">self</span>.last_layers)[<span class="hljs-number">0</span>]
        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear1(<span class="hljs-symbol">y</span>)
        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear2(<span class="hljs-symbol">y</span>)
        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
        <span class="hljs-symbol">y</span> = <span class="hljs-literal">self</span>.append_linear3(<span class="hljs-symbol">y</span>)
        <span class="hljs-symbol">y</span> = autograd.relu(<span class="hljs-symbol">y</span>)
        <span class="hljs-keyword">return</span> <span class="hljs-symbol">y</span>

autograd.training = False
model = Trans(sg_ir, <span class="hljs-number">-1</span>)

# <span class="hljs-keyword">then</span> you train the model as usual
</code></pre>
<h2><a class="anchor" aria-hidden="true" id="一个完整示例"></a><a href="#一个完整示例" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>一个完整示例</h2>
<p>本部分以mnist为例，介绍SINGA ONNX的使用方法。在这部分，将展示如何导出、加载、inference、再训练和迁移学习 mnist 模型的例子。您可以在<a href="https://colab.research.google.com/drive/1-YOfQqqw3HNhS8WpB8xjDQYutRdUdmCq">这里</a>试用这部分内容。</p>
<h3><a class="anchor" aria-hidden="true" id="读取数据集"></a><a href="#读取数据集" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>读取数据集</h3>
<p>首先，你需要导入一些必要的库，并定义一些辅助函数来下载和预处理数据集：</p>
<pre><code class="hljs css language-python"><span class="hljs-keyword">import</span> os
<span class="hljs-keyword">import</span> urllib.request
<span class="hljs-keyword">import</span> gzip
<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-keyword">import</span> codecs

<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> device
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> tensor
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> autograd
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> sonnx
<span class="hljs-keyword">import</span> onnx


<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">load_dataset</span><span class="hljs-params">()</span>:</span>
    train_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'</span>
    train_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'</span>
    valid_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'</span>
    valid_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'</span>
    train_x = read_image_file(check_exist_or_download(train_x_url)).astype(
        np.float32)
    train_y = read_label_file(check_exist_or_download(train_y_url)).astype(
        np.float32)
    valid_x = read_image_file(check_exist_or_download(valid_x_url)).astype(
        np.float32)
    valid_y = read_label_file(check_exist_or_download(valid_y_url)).astype(
        np.float32)
    <span class="hljs-keyword">return</span> train_x, train_y, valid_x, valid_y


<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">check_exist_or_download</span><span class="hljs-params">(url)</span>:</span>

    download_dir = <span class="hljs-string">'/tmp/'</span>

    name = url.rsplit(<span class="hljs-string">'/'</span>, <span class="hljs-number">1</span>)[<span class="hljs-number">-1</span>]
    filename = os.path.join(download_dir, name)
    <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> os.path.isfile(filename):
        print(<span class="hljs-string">"Downloading %s"</span> % url)
        urllib.request.urlretrieve(url, filename)
    <span class="hljs-keyword">return</span> filename


<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_label_file</span><span class="hljs-params">(path)</span>:</span>
    <span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
        data = f.read()
        <span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2049</span>
        length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">8</span>).reshape(
            (length))
        <span class="hljs-keyword">return</span> parsed


<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">get_int</span><span class="hljs-params">(b)</span>:</span>
    <span class="hljs-keyword">return</span> int(codecs.encode(b, <span class="hljs-string">'hex'</span>), <span class="hljs-number">16</span>)


<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_image_file</span><span class="hljs-params">(path)</span>:</span>
    <span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
        data = f.read()
        <span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2051</span>
        length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
        num_rows = get_int(data[<span class="hljs-number">8</span>:<span class="hljs-number">12</span>])
        num_cols = get_int(data[<span class="hljs-number">12</span>:<span class="hljs-number">16</span>])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">16</span>).reshape(
            (length, <span class="hljs-number">1</span>, num_rows, num_cols))
        <span class="hljs-keyword">return</span> parsed


<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">to_categorical</span><span class="hljs-params">(y, num_classes)</span>:</span>
    y = np.array(y, dtype=<span class="hljs-string">"int"</span>)
    n = y.shape[<span class="hljs-number">0</span>]
    categorical = np.zeros((n, num_classes))
    categorical[np.arange(n), y] = <span class="hljs-number">1</span>
    categorical = categorical.astype(np.float32)
    <span class="hljs-keyword">return</span> categorical
</code></pre>
<h3><a class="anchor" aria-hidden="true" id="mnist模型"></a><a href="#mnist模型" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>MNIST模型</h3>
<p>然后你可以定义一个叫做<strong>CNN</strong>的类来构造mnist模型，这个模型由几个卷积层、池化层、全连接层和relu层组成。你也可以定义一个函数来计算我们结果的<strong>准确性</strong>。最后，你可以定义一个<strong>训练函数</strong>和一个<strong>测试函数</strong>来处理训练和预测的过程。</p>
<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">CNN</span>:</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self)</span>:</span>
        self.conv1 = autograd.Conv2d(<span class="hljs-number">1</span>, <span class="hljs-number">20</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
        self.conv2 = autograd.Conv2d(<span class="hljs-number">20</span>, <span class="hljs-number">50</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
        self.linear1 = autograd.Linear(<span class="hljs-number">4</span> * <span class="hljs-number">4</span> * <span class="hljs-number">50</span>, <span class="hljs-number">500</span>, bias=<span class="hljs-literal">False</span>)
        self.linear2 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)
        self.pooling1 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
        self.pooling2 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
        y = self.conv1(x)
        y = autograd.relu(y)
        y = self.pooling1(y)
        y = self.conv2(y)
        y = autograd.relu(y)
        y = self.pooling2(y)
        y = autograd.flatten(y)
        y = self.linear1(y)
        y = autograd.relu(y)
        y = self.linear2(y)
        <span class="hljs-keyword">return</span> y


<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">accuracy</span><span class="hljs-params">(pred, target)</span>:</span>
    y = np.argmax(pred, axis=<span class="hljs-number">1</span>)
    t = np.argmax(target, axis=<span class="hljs-number">1</span>)
    a = y == t
    <span class="hljs-keyword">return</span> np.array(a, <span class="hljs-string">"int"</span>).sum() / float(len(t))


<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train</span><span class="hljs-params">(model,
          x,
          y,
          epochs=<span class="hljs-number">1</span>,
          batch_size=<span class="hljs-number">64</span>,
          dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size

    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
            l_idx = b * batch_size
            r_idx = (b + <span class="hljs-number">1</span>) * batch_size

            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])

            output_batch = model.forward(x_batch)
            <span class="hljs-comment"># onnx_model = sonnx.to_onnx([x_batch], [y])</span>
            <span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>

            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
                                     tensor.to_numpy(target_batch))

            sgd = opt.SGD(lr=<span class="hljs-number">0.001</span>)
            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
                sgd.update(p, gp)
            sgd.step()

            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
    print(<span class="hljs-string">"training completed"</span>)
    <span class="hljs-keyword">return</span> x_batch, output_batch

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">test</span><span class="hljs-params">(model, x, y, batch_size=<span class="hljs-number">64</span>, dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size

    result = <span class="hljs-number">0</span>
    <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
        l_idx = b * batch_size
        r_idx = (b + <span class="hljs-number">1</span>) * batch_size

        x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
        target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])

        output_batch = model.forward(x_batch)
        result += accuracy(tensor.to_numpy(output_batch),
                           tensor.to_numpy(target_batch))

    print(<span class="hljs-string">"testing acc %6.2f"</span> % (result / batch_number))
</code></pre>
<h3><a class="anchor" aria-hidden="true" id="训练mnist模型并将其导出到onnx"></a><a href="#训练mnist模型并将其导出到onnx" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>训练mnist模型并将其导出到onnx</h3>
<p>现在，你可以训练 mnist 模型，并通过调用 <strong>sonnx.to_onnx</strong> 函数导出其 onnx 模型。</p>
<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">make_onnx</span><span class="hljs-params">(x, y)</span>:</span>
    <span class="hljs-keyword">return</span> sonnx.to_onnx([x], [y])

<span class="hljs-comment"># create device</span>
dev = device.create_cuda_gpu()
<span class="hljs-comment">#dev = device.get_default_device()</span>
<span class="hljs-comment"># create model</span>
model = CNN()
<span class="hljs-comment"># load data</span>
train_x, train_y, valid_x, valid_y = load_dataset()
<span class="hljs-comment"># normalization</span>
train_x = train_x / <span class="hljs-number">255</span>
valid_x = valid_x / <span class="hljs-number">255</span>
train_y = to_categorical(train_y, <span class="hljs-number">10</span>)
valid_y = to_categorical(valid_y, <span class="hljs-number">10</span>)
<span class="hljs-comment"># do training</span>
autograd.training = <span class="hljs-literal">True</span>
x, y = train(model, train_x, train_y, dev=dev)
onnx_model = make_onnx(x, y)
<span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>

<span class="hljs-comment"># Save the ONNX model</span>
model_path = os.path.join(<span class="hljs-string">'/'</span>, <span class="hljs-string">'tmp'</span>, <span class="hljs-string">'mnist.onnx'</span>)
onnx.save(onnx_model, model_path)
print(<span class="hljs-string">'The model is saved.'</span>)
</code></pre>
<h3><a class="anchor" aria-hidden="true" id="inference"></a><a href="#inference" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference</h3>
<p>导出onnx模型后，可以在'/tmp'目录下找到一个名为<strong>mnist.onnx</strong>的文件，这个模型可以被其他库导入。现在，如果你想把这个onnx模型再次导入到singa中，并使用验证数据集进行推理，你可以定义一个叫做<strong>Infer</strong>的类，Infer的前向函数将被测试函数调用，对验证数据集进行推理。此外，你应该把训练标志（<code>autograd.training</code>）设置为<strong>False</strong>，以固定autograd算子的梯度。</p>
<p>在导入onnx模型时，需要先调用<strong>onnx.load</strong>来加载onnx模型。然后将onnx模型输入到 <strong>sonnx.prepare</strong>中进行解析，并初始化为一个singa模型(代码中的<strong>sg_ir</strong>)。sg_ir里面包含了一个singa图，然后就可以通过将输入传给它的run函数来运行一步推理。</p>
<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span>
        self.sg_ir = sg_ir
        <span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items():
            <span class="hljs-comment"># allow the tensors to be updated</span>
            tens.requires_grad = <span class="hljs-literal">True</span>
            tens.stores_grad= <span class="hljs-literal">True</span>
            sg_ir.tensor_map[idx] = tens

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
        <span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>] <span class="hljs-comment"># we can run one step of inference by feeding input</span>

<span class="hljs-comment"># load the ONNX model</span>
onnx_model = onnx.load(model_path)
sg_ir = sonnx.prepare(onnx_model, device=dev) <span class="hljs-comment"># parse and initiate to a singa model</span>

<span class="hljs-comment"># inference</span>
autograd.training = <span class="hljs-literal">False</span>
print(<span class="hljs-string">'The inference result is:'</span>)
test(Infer(sg_ir), valid_x, valid_y, dev=dev)
</code></pre>
<h3><a class="anchor" aria-hidden="true" id="重训练"></a><a href="#重训练" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>重训练</h3>
<p>假设导入模型后，想再次对模型进行重新训练，我们可以定义一个名为<strong>re_train</strong>的函数。在调用这个re_train函数之前，我们应该将训练标志（<code>autograd.training</code>）设置为<strong>True</strong>，以使autograd运算符更新其梯度。而在完成训练后，我们再将其设置为<strong>False</strong>，以调用做推理的测试函数。</p>
<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">re_train</span><span class="hljs-params">(sg_ir,
             x,
             y,
             epochs=<span class="hljs-number">1</span>,
             batch_size=<span class="hljs-number">64</span>,
             dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size

    new_model = Infer(sg_ir)

    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
            l_idx = b * batch_size
            r_idx = (b + <span class="hljs-number">1</span>) * batch_size

            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])

            output_batch = new_model.forward(x_batch)

            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
                                     tensor.to_numpy(target_batch))

            sgd = opt.SGD(lr=<span class="hljs-number">0.01</span>)
            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
                sgd.update(p, gp)
            sgd.step()

            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
    print(<span class="hljs-string">"re-training completed"</span>)
    <span class="hljs-keyword">return</span> new_model

<span class="hljs-comment"># load the ONNX model</span>
onnx_model = onnx.load(model_path)
sg_ir = sonnx.prepare(onnx_model, device=dev)

<span class="hljs-comment"># re-training</span>
autograd.training = <span class="hljs-literal">True</span>
new_model = re_train(sg_ir, train_x, train_y, dev=dev)
autograd.training = <span class="hljs-literal">False</span>
test(new_model, valid_x, valid_y, dev=dev)
</code></pre>
<h3><a class="anchor" aria-hidden="true" id="迁移学习"></a><a href="#迁移学习" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>迁移学习</h3>
<p>最后，如果我们想做迁移学习，我们可以定义一个名为<strong>Trans</strong>的类，在onnx模型后追加一些层。为了演示，代码中只在onnx模型后追加了几个线性（全连接）层和relu。可以定义一个transfer_learning函数来处理transfer-learning模型的训练过程，而训练标志的设置和上一节相同。</p>
<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Trans</span>:</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir, last_layers)</span>:</span>
        self.sg_ir = sg_ir
        self.last_layers = last_layers
        self.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=<span class="hljs-literal">False</span>)
        self.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=<span class="hljs-literal">False</span>)
        self.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
        y = sg_ir.run([x], last_layers=self.last_layers)[<span class="hljs-number">0</span>]
        y = self.append_linear1(y)
        y = autograd.relu(y)
        y = self.append_linear2(y)
        y = autograd.relu(y)
        y = self.append_linear3(y)
        y = autograd.relu(y)
        <span class="hljs-keyword">return</span> y

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">transfer_learning</span><span class="hljs-params">(sg_ir,
             x,
             y,
             epochs=<span class="hljs-number">1</span>,
             batch_size=<span class="hljs-number">64</span>,
             dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
    batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size

    trans_model = Trans(sg_ir, <span class="hljs-number">-1</span>)

    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
        <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
            l_idx = b * batch_size
            r_idx = (b + <span class="hljs-number">1</span>) * batch_size

            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
            output_batch = trans_model.forward(x_batch)

            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
                                     tensor.to_numpy(target_batch))

            sgd = opt.SGD(lr=<span class="hljs-number">0.07</span>)
            <span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
                sgd.update(p, gp)
            sgd.step()

            <span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
                print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
                      (accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
    print(<span class="hljs-string">"transfer-learning completed"</span>)
    <span class="hljs-keyword">return</span> trans_model

<span class="hljs-comment"># load the ONNX model</span>
onnx_model = onnx.load(model_path)
sg_ir = sonnx.prepare(onnx_model, device=dev)

<span class="hljs-comment"># transfer-learning</span>
autograd.training = <span class="hljs-literal">True</span>
new_model = transfer_learning(sg_ir, train_x, train_y, dev=dev)
autograd.training = <span class="hljs-literal">False</span>
test(new_model, valid_x, valid_y, dev=dev)
</code></pre>
<h2><a class="anchor" aria-hidden="true" id="onnx模型库"></a><a href="#onnx模型库" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>ONNX模型库</h2>
<p><a href="https://github.com/onnx/models">ONNX 模型库</a>是由社区成员贡献的 ONNX 格式的预先训练的最先进模型的集合。SINGA 现在已经支持了几个 CV 和 NLP 模型。将来会支持更多模型。</p>
<h3><a class="anchor" aria-hidden="true" id="图像分类"></a><a href="#图像分类" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>图像分类</h3>
<p>这套模型以图像作为输入，然后将图像中的主要物体分为1000个物体类别，如键盘、鼠标、铅笔和许多动物。</p>
<table>
<thead>
<tr><th>Model Class</th><th>Reference</th><th>Description</th><th>Link</th></tr>
</thead>
<tbody>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/mobilenet">MobileNet</a></b></td><td><a href="https://arxiv.org/abs/1801.04381">Sandler et al.</a></td><td>最适合移动和嵌入式视觉应用的轻量级深度神经网络。 <br>Top-5 error from paper - ~10%</td><td><a href="https://colab.research.google.com/drive/1HsixqJMIpKyEPhkbB8jy7NwNEFEAUWAf"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/resnet">ResNet18</a></b></td><td><a href="https://arxiv.org/abs/1512.03385">He et al.</a></td><td>一个CNN模型（多达152层），在对图像进行分类时，使用shortcut来实现更高的准确性。 <br> Top-5 error from paper - ~3.6%</td><td><a href="https://colab.research.google.com/drive/1u1RYefSsVbiP4I-5wiBKHjsT9L0FxLm9"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/vgg">VGG16</a></b></td><td><a href="https://arxiv.org/abs/1409.1556">Simonyan et al.</a></td><td>深度CNN模型（多达19层）。类似于AlexNet，但使用多个较小的内核大小的滤波器，在分类图像时提供更高的准确性。 <br>Top-5 error from paper - ~8%</td><td><a href="https://colab.research.google.com/drive/14kxgRKtbjPCKKsDJVNi3AvTev81Gp_Ds"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/classification/shufflenet">ShuffleNet_V2</a></b></td><td><a href="https://arxiv.org/pdf/1707.01083.pdf">Zhang et al.</a></td><td>专门为移动设备设计的计算效率极高的CNN模型。这种网络架构设计考虑了速度等直接指标，而不是FLOP等间接指标。 <br>Top-1 error from paper - ~30.6%</td><td><a href="https://colab.research.google.com/drive/19HfRu3YHP_H2z3BcZujVFRp23_J5XsuA?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
</tbody>
</table>
<h3><a class="anchor" aria-hidden="true" id="目标检测"></a><a href="#目标检测" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>目标检测</h3>
<p>目标检测模型可以检测图像中是否存在多个对象，并将图像中检测到对象的区域分割出来。</p>
<table>
<thead>
<tr><th>Model Class</th><th>Reference</th><th>Description</th><th>Link</th></tr>
</thead>
<tbody>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/tiny_yolov2">Tiny YOLOv2</a></b></td><td><a href="https://arxiv.org/pdf/1612.08242.pdf">Redmon et al.</a></td><td>一个用于目标检测的实时CNN，可以检测20个不同的类。一个更复杂的完整YOLOv2网络的小版本。</td><td><a href="https://colab.research.google.com/drive/11V4I6cRjIJNUv5ZGsEGwqHuoQEie6b1T"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
</tbody>
</table>
<h3><a class="anchor" aria-hidden="true" id="面部识别"></a><a href="#面部识别" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>面部识别</h3>
<p>人脸检测模型可以检测和/或识别给定图像中的人脸和情绪。</p>
<table>
<thead>
<tr><th>Model Class</th><th>Reference</th><th>Description</th><th>Link</th></tr>
</thead>
<tbody>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/body_analysis/arcface">ArcFace</a></b></td><td><a href="https://arxiv.org/abs/1801.07698">Deng et al.</a></td><td>一种基于CNN的人脸识别模型，它可以学习人脸的判别特征，并对输入的人脸图像进行分析。</td><td><a href="https://colab.research.google.com/drive/1qanaqUKGIDtifdzEzJOHjEj4kYzA9uJC"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/vision/body_analysis/emotion_ferplus">Emotion FerPlus</a></b></td><td><a href="https://arxiv.org/abs/1608.01041">Barsoum et al.</a></td><td>基于人脸图像训练的情感识别深度CNN。</td><td><a href="https://colab.research.google.com/drive/1XHtBQGRhe58PDi4LGYJzYueWBeWbO23r"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
</tbody>
</table>
<h3><a class="anchor" aria-hidden="true" id="机器理解"></a><a href="#机器理解" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>机器理解</h3>
<p>这个自然语言处理模型的子集，可以回答关于给定上下文段落的问题。</p>
<table>
<thead>
<tr><th>Model Class</th><th>Reference</th><th>Description</th><th>Link</th></tr>
</thead>
<tbody>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/text/machine_comprehension/bert-squad">BERT-Squad</a></b></td><td><a href="https://arxiv.org/pdf/1810.04805.pdf">Devlin et al.</a></td><td>该模型根据给定输入段落的上下文回答问题。</td><td><a href="https://colab.research.google.com/drive/1kud-lUPjS_u-TkDAzihBTw0Vqr0FjCE-"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/text/machine_comprehension/roberta">RoBERTa</a></b></td><td><a href="https://arxiv.org/pdf/1907.11692.pdf">Liu et al.</a></td><td>一个基于Transformer的大型模型，根据给定的输入文本预测情感。</td><td><a href="https://colab.research.google.com/drive/1F-c4LJSx3Cb2jW6tP7f8nAZDigyLH6iN?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
<tr><td><b><a href="https://github.com/onnx/models/tree/master/text/machine_comprehension/gpt-2">GPT-2</a></b></td><td><a href="https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf">Radford et al.</a></td><td>一个基于Transformer的大型语言模型，给定一些文本中的单词序列，预测下一个单词。</td><td><a href="https://colab.research.google.com/drive/1ZlXLSIMppPch6HgzKRillJiUcWn3PiK7?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a></td></tr>
</tbody>
</table>
<h2><a class="anchor" aria-hidden="true" id="支持的操作符"></a><a href="#支持的操作符" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>支持的操作符</h2>
<p>SINGA 支持下列 ONNX 运算符：</p>
<ul>
<li>Acos</li>
<li>Acosh</li>
<li>Add</li>
<li>And</li>
<li>Asin</li>
<li>Asinh</li>
<li>Atan</li>
<li>Atanh</li>
<li>AveragePool</li>
<li>BatchNormalization</li>
<li>Cast</li>
<li>Ceil</li>
<li>Clip</li>
<li>Concat</li>
<li>ConstantOfShape</li>
<li>Conv</li>
<li>Cos</li>
<li>Cosh</li>
<li>Div</li>
<li>Dropout</li>
<li>Elu</li>
<li>Equal</li>
<li>Erf</li>
<li>Expand</li>
<li>Flatten</li>
<li>Gather</li>
<li>Gemm</li>
<li>GlobalAveragePool</li>
<li>Greater</li>
<li>HardSigmoid</li>
<li>Identity</li>
<li>LeakyRelu</li>
<li>Less</li>
<li>Log</li>
<li>MatMul</li>
<li>Max</li>
<li>MaxPool</li>
<li>Mean</li>
<li>Min</li>
<li>Mul</li>
<li>Neg</li>
<li>NonZero</li>
<li>Not</li>
<li>OneHot</li>
<li>Or</li>
<li>Pad</li>
<li>Pow</li>
<li>PRelu</li>
<li>Reciprocal</li>
<li>ReduceMean</li>
<li>ReduceSum</li>
<li>Relu</li>
<li>Reshape</li>
<li>ScatterElements</li>
<li>Selu</li>
<li>Shape</li>
<li>Sigmoid</li>
<li>Sign</li>
<li>Sin</li>
<li>Sinh</li>
<li>Slice</li>
<li>Softmax</li>
<li>Softplus</li>
<li>Softsign</li>
<li>Split</li>
<li>Sqrt</li>
<li>Squeeze</li>
<li>Sub</li>
<li>Sum</li>
<li>Tan</li>
<li>Tanh</li>
<li>Tile</li>
<li>Transpose</li>
<li>Unsqueeze</li>
<li>Upsample</li>
<li>Where</li>
<li>Xor</li>
</ul>
<h3><a class="anchor" aria-hidden="true" id="对onnx后端的特别说明"></a><a href="#对onnx后端的特别说明" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>对ONNX后端的特别说明</h3>
<ul>
<li><p>Conv, MaxPool 以及 AveragePool</p>
<p>输入必须是1d<code>(N*C*H)</code>或2d<code>(N*C*H*W)</code>的形状，<code>dilation</code>必须是1。</p></li>
<li><p>BatchNormalization</p>
<p><code>epsilon</code> 设定为1e-05，不能改变</p></li>
<li><p>Cast</p>
<p>只支持float32和int32，其他类型都会转向这两种类型。</p></li>
<li><p>Squeeze and Unsqueeze</p>
<p>如果你在<code>Tensor</code>和Scalar之间<code>Squeeze</code>或<code>Unsqueeze</code>时遇到错误，请向我们报告。</p></li>
<li><p>Empty tensor</p>
<p>空张量在SINGA是非法的。</p></li>
</ul>
<h2><a class="anchor" aria-hidden="true" id="实现"></a><a href="#实现" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>实现</h2>
<p>SINGA ONNX的代码在<code>python/singa/sonnx.py</code>中，主要有三个类，<code>SingaFrontend</code>、<code>SingaBackend</code>和<code>SingaRep</code>。<code>SingaFrontend</code>将SINGA模型翻译成ONNX模型；<code>SingaBackend</code>将ONNX模型翻译成<code>SingaRep</code>对象，其中存储了所有的SINGA运算符和张量（本文档中的张量指SINGA Tensor）；<code>SingaRep</code>可以像SINGA模型一样运行。</p>
<h3><a class="anchor" aria-hidden="true" id="singafrontend"></a><a href="#singafrontend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaFrontend</h3>
<p><code>SingaFrontend</code>的入口函数是<code>singa_to_onnx_model</code>，它也被称为<code>to_onnx</code>，<code>singa_to_onnx_model</code>创建了ONNX模型，它还通过<code>singa_to_onnx_graph</code>创建了一个ONNX图。</p>
<p><code>singa_to_onnx_graph</code>接受模型的输出，并从输出中递归迭代SINGA模型的图，得到所有的运算符，形成一个队列。SINGA模型的输入和中间张量，即可训练的权重，同时被获取。输入存储在<code>onnx_model.graph.input</code>中；输出存储在<code>onnx_model.graph.output</code>中；可训练权重存储在<code>onnx_model.graph.initializer</code>中。</p>
<p>然后将队列中的SINGA运算符逐一翻译成ONNX运算符。<code>_rename_operators</code> 定义了 SINGA 和 ONNX 之间的运算符名称映射。<code>_special_operators</code> 定义了翻译运算符时要使用的函数。</p>
<p>此外，SINGA 中的某些运算符与 ONNX 的定义不同，即 ONNX 将 SINGA 运算符的某些属性视为输入，因此 <code>_unhandled_operators</code> 定义了处理特殊运算符的函数。</p>
<p>由于SINGA中的布尔类型被视为int32，所以<code>_bool_operators</code>定义了要改变的操作符为布尔类型。</p>
<h3><a class="anchor" aria-hidden="true" id="singabackend"></a><a href="#singabackend" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaBackend</h3>
<p><code>SingaBackend</code>的入口函数是<code>prepare</code>，它检查ONNX模型的版本，然后调用<code>_onnx_model_to_singa_net</code>。</p>
<p><code>_onnx_model_to_singa_net</code>的目的是获取SINGA的张量和运算符。张量以其在ONNX中的名称为键存储在字典中，而运算符则以<code>namedtuple('SingaOps', ['name', 'op', 'handle', 'forward'])</code>的形式存储在队列中。对于每个运算符，<code>name</code>是它的ONNX节点名称；<code>op</code>是ONNX节点；<code>forward</code>是SINGA运算符的前向函数；<code>handle</code>是为一些特殊的运算符准备的，如Conv和Pooling，它们有<code>handle</code>对象。</p>
<p><code>_onnx_model_to_singa_net</code>的第一步是调用<code>_init_graph_parameter</code>来获取模型内的所有tensors。对于可训练的权重，可以从<code>onnx_model.graph.initializer</code>中初始化<code>SINGA Tensor</code>。请注意，权重也可能存储在图的输入或称为<code>Constant</code>的ONNX节点中，SINGA也可以处理这些。</p>
<p>虽然所有的权重都存储在ONNX模型中，但模型的输入是未知的，只知道它的形状和类型。所以SINGA支持两种方式来初始化输入：1、根据其形状和类型生成随机张量；2、允许用户指定输入。第一种方法对大多数模型都很好，但是对于一些模型，比如BERT，矩阵的索引不能随机生成，否则会产生错误。</p>
<p>然后，<code>_onnx_model_to_singa_net</code>迭代ONNX图中的所有节点，将其翻译成SINGA运算符。另外，<code>_rename_operators</code> 定义了 SINGA 和 ONNX 之间的运算符名称映射。<code>_special_operators</code> 定义翻译运算符时要使用的函数。<code>_run_node</code>通过输入张量来运行生成的 SINGA 模型，并存储其输出张量，供以后的运算符使用。</p>
<p>该类最后返回一个<code>SingaRep</code>对象，并在其中存储所有SINGA张量和运算符。</p>
<h3><a class="anchor" aria-hidden="true" id="singarep"></a><a href="#singarep" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>SingaRep</h3>
<p><code>SingaRep</code>存储所有的SINGA tensors和运算符。<code>run</code>接受模型的输入，并按照运算符队列逐个运行SINGA运算符。用户可以使用<code>last_layers</code>来决定是否将模型运行到最后几层。将 <code>all_outputs</code> 设置为 <code>False</code> 表示只得到最后的输出，设置为 <code>True</code> 表示也得到所有的中间输出。</p>
</span></div></article></div><div class="docLastUpdate"><em>Last updated on 20/09/2020</em></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/3.1.0_Chinese/graph"><span class="arrow-prev">← </span><span>Model</span></a><a class="docs-next button" href="/docs/3.1.0_Chinese/dist-train"><span>Distributed Training</span><span class="arrow-next"> →</span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#通常用法">通常用法</a><ul class="toc-headings"><li><a href="#从onnx中读取一个model到singa">从ONNX中读取一个Model到SINGA</a></li><li><a href="#inference-singa模型">Inference SINGA模型</a></li><li><a href="#将singa模型保存成onnx格式">将SINGA模型保存成ONNX格式</a></li><li><a href="#在onnx模型上重新训练">在ONNX模型上重新训练</a></li><li><a href="#在onnx模型上做迁移学习">在ONNX模型上做迁移学习</a></li></ul></li><li><a href="#一个完整示例">一个完整示例</a><ul class="toc-headings"><li><a href="#读取数据集">读取数据集</a></li><li><a href="#mnist模型">MNIST模型</a></li><li><a href="#训练mnist模型并将其导出到onnx">训练mnist模型并将其导出到onnx</a></li><li><a href="#inference">Inference</a></li><li><a href="#重训练">重训练</a></li><li><a href="#迁移学习">迁移学习</a></li></ul></li><li><a href="#onnx模型库">ONNX模型库</a><ul class="toc-headings"><li><a href="#图像分类">图像分类</a></li><li><a href="#目标检测">目标检测</a></li><li><a href="#面部识别">面部识别</a></li><li><a href="#机器理解">机器理解</a></li></ul></li><li><a href="#支持的操作符">支持的操作符</a><ul class="toc-headings"><li><a href="#对onnx后端的特别说明">对ONNX后端的特别说明</a></li></ul></li><li><a href="#实现">实现</a><ul class="toc-headings"><li><a href="#singafrontend">SingaFrontend</a></li><li><a href="#singabackend">SingaBackend</a></li><li><a href="#singarep">SingaRep</a></li></ul></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a href="/en/https://apache-singa.readthedocs.io/en/latest/">API Reference</a><a 
href="/docs/examples">Examples</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/blog">SINGA News</a><a href="https://github.com/apache/singa">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2021
   The Apache Software Foundation. All rights reserved.
   Apache SINGA, Apache, the Apache feather logo, and
   the Apache SINGA project logos are trademarks of The
   Apache Software Foundation. All other marks mentioned
   may be trademarks or registered trademarks of their
   respective owners.</section></div></footer></div><script type="text/javascript" src="https://cdn.jsdelivr.net/docsearch.js/1/docsearch.min.js"></script><script>window.twttr=(function(d,s, id){var js,fjs=d.getElementsByTagName(s)[0],t=window.twttr||{};if(d.getElementById(id))return t;js=d.createElement(s);js.id=id;js.src='https://platform.twitter.com/widgets.js';fjs.parentNode.insertBefore(js, fjs);t._e = [];t.ready = function(f) {t._e.push(f);};return t;}(document, 'script', 'twitter-wjs'));</script><script>
                // Keyboard shortcut: releasing '/' while nothing but <body> is the
                // event target moves focus to the search input.
                document.addEventListener('keyup', function(event) {
                  // 191 is the keyCode for '/' (slash). Keys released while another
                  // element is the target (e.g. typing in a field) are ignored.
                  const slashReleasedOnBody =
                    event.target === document.body && event.keyCode === 191;
                  if (!slashReleasedOnBody) {
                    return;
                  }
                  const searchInput = document.getElementById('search_input_react');
                  // The input may be absent; only focus it when it was found.
                  if (searchInput) {
                    searchInput.focus();
                  }
                });
              </script><script>
              // Set up DocSearch on the element matched by '#search_input_react'
              // and keep the returned handle in the global `search` variable.
              // NOTE(review): this page's <head> declares docsearch:version
              // "3.1.0_Chinese", while the facet filter below pins "version:3.0.0"
              // — confirm the mismatch is intended.
              var search = docsearch({
                inputSelector: '#search_input_react',
                indexName: 'apache_singa',
                apiKey: '45202133606c0b5fa6d21cddc4725dd8',
                // Extra query options handed to Algolia: filter by language and version.
                algoliaOptions: {
                  facetFilters: ['language:en', 'version:3.0.0']
                }
              });
            </script></body></html>