<!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -->
<h1 id="autograd">Autograd</h1>
<p>There are two typical ways to implement autograd: symbolic differentiation, as in <a href="http://deeplearning.net/software/theano/index.html">Theano</a>, or reverse differentiation, as in <a href="https://pytorch.org/docs/stable/notes/autograd.html">PyTorch</a>. SINGA follows the PyTorch approach: it records the computation graph and applies backward propagation automatically after the forward propagation. The autograd algorithm is explained in detail <a href="https://pytorch.org/docs/stable/notes/autograd.html">here</a>. Below we explain the relevant modules in SINGA and give examples to illustrate their usage.</p>
<h2><a class="anchor" aria-hidden="true" id="相关模块"></a><a href="#相关模块" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>相关模块</h2>
<p>There are three classes involved in autograd, namely <code>singa.tensor.Tensor</code>, <code>singa.autograd.Operation</code>, and <code>singa.autograd.Layer</code>. In the rest of this article, we use Tensor, Operation, and Layer to refer to these three classes.</p>
<h3><a class="anchor" aria-hidden="true" id="tensor"></a><a href="#tensor" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Tensor</h3>
<p>Three attributes of Tensor are used by autograd:</p>
<ul>
<li><code>.creator</code> is an <code>Operation</code> instance. It records the operation that generated the Tensor instance.</li>
<li><code>.requires_grad</code> is a boolean variable. It indicates whether the autograd algorithm needs to compute the gradient of the tensor. For example, during backward propagation, the gradients of the weight matrix of a linear layer and of the feature maps of a convolution layer (not the bottom layer) should be computed.</li>
<li><code>.stores_grad</code> is a boolean variable. It indicates whether the gradient of the tensor should be stored and output by the backward function. For example, the gradients of the feature maps are computed during backward propagation, but are not included in the output of the backward function.</li>
</ul>
<p>Developers can change the <code>requires_grad</code> and <code>stores_grad</code> of a Tensor instance. For example, if the latter is set to True, the corresponding gradient is included in the output of the backward function. Note that if <code>stores_grad</code> is True, then <code>requires_grad</code> must be True as well, but not vice versa.</p>
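<p>The snippet below is a minimal sketch of how these attributes are typically used; it relies only on the <code>Tensor</code> constructor and <code>autograd.matmul</code>, which also appear in the examples later on this page.</p>
<pre><code class="language-python">from singa.tensor import Tensor
from singa import autograd

autograd.training = True  # record the graph during the forward pass

# The parameter w asks for its gradient to be computed and stored;
# the input x does not.
w = Tensor(shape=(2, 3), requires_grad=True, stores_grad=True)
w.gaussian(0.0, 0.1)
x = Tensor(shape=(4, 2), requires_grad=False, stores_grad=False)
x.gaussian(0.0, 1.0)

y = autograd.matmul(x, w)  # y.creator records the matmul Operation that produced y
</code></pre>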
<h3><a class="anchor" aria-hidden="true" id="operation"></a><a href="#operation" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Operation</h3>
<p>An Operation takes one or more <code>Tensor</code> instances as input and outputs one or more <code>Tensor</code> instances. For example, ReLU can be implemented as a concrete Operation subclass. When an <code>Operation</code> instance is called (after instantiation), the following two steps are executed:</p>
<ol>
<li>record the source operations, i.e. the <code>creator</code>s of the input tensors;</li>
<li>do the calculation by calling the member function <code>.forward()</code>.</li>
</ol>
<p>There are two member functions for forward and backward propagation, i.e. <code>.forward()</code> and <code>.backward()</code>. They take <code>Tensor.data</code> (of type <code>CTensor</code>) as input and output <code>CTensor</code>s. To add a specific operation, the <code>Operation</code> subclass should implement its own <code>.forward()</code> and <code>.backward()</code>. The <code>.backward()</code> function is called automatically by autograd's <code>backward()</code> function during backward propagation to compute the gradients of the inputs (according to the <code>requires_grad</code> field).</p>
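<p>As a concrete illustration, the following is a minimal sketch of a user-defined Operation. It uses an identity mapping so that the CTensor math stays trivial; only the <code>forward()</code>/<code>backward()</code> contract described above is assumed.</p>
<pre><code class="language-python">from singa import autograd

class Identity(autograd.Operation):
    """A do-nothing operation: the output equals the input."""

    def forward(self, x):
        # x is a CTensor (i.e. Tensor.data); return one or more CTensors.
        return x

    def backward(self, dy):
        # dy is the gradient CTensor w.r.t. the output; return the
        # gradient(s) w.r.t. the input(s). For identity it is unchanged.
        return dy
</code></pre>
<p>Calling <code>Identity()(x)</code> on a Tensor <code>x</code> then behaves like any built-in operation: the source operations are recorded and the output Tensor gets the Identity instance as its <code>creator</code>.</p>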
<h3><a class="anchor" aria-hidden="true" id="layer"></a><a href="#layer" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Layer</h3>
<p>For those operations that require parameters, we wrap them into a new class, <code>Layer</code>. For example, the convolution operation is wrapped into a convolution layer. A <code>Layer</code> manages (stores) the parameters and calls the corresponding <code>Operation</code>s to implement the transformation.</p>
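<p>For instance, a linear layer holds its weight and bias internally and applies the underlying matmul and bias-add operations when it is called. A minimal sketch, using the <code>autograd.Linear</code> layer that also appears in the CNN example below:</p>
<pre><code class="language-python">from singa.tensor import Tensor
from singa import autograd

autograd.training = True
fc = autograd.Linear(3, 2)  # the layer creates and stores W (3x2) and b (2,)

x = Tensor(shape=(4, 3))    # a batch of 4 inputs with 3 features each
x.gaussian(0.0, 1.0)
y = fc(x)                   # the layer calls its wrapped operations internally
</code></pre>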
<h2><a class="anchor" aria-hidden="true" id="样例"></a><a href="#样例" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>样例</h2>
<p><a href="https://github.com/apache/singa/tree/master/examples/autograd">example folder</a>中提供了很多样例。在这里我我们分析两个最具代表性的例子。</p>
<h3><a class="anchor" aria-hidden="true" id="只使用operation"></a><a href="#只使用operation" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>只使用Operation</h3>
<p>The following code implements a multi-layer perceptron (MLP) model using only <code>Operation</code> instances (no <code>Layer</code> instances):</p>
<h4><a class="anchor" aria-hidden="true" id="调用依赖包"></a><a href="#调用依赖包" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>调用依赖包</h4>
<pre><code class="hljs css language-python"><span class="hljs-keyword">from</span> singa.tensor <span class="hljs-keyword">import</span> Tensor
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> autograd
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
</code></pre>
<h4><a class="anchor" aria-hidden="true" id="创建权重矩阵和偏置向量"></a><a href="#创建权重矩阵和偏置向量" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>创建权重矩阵和偏置向量</h4>
<p>The parameter tensors are created with both <code>requires_grad</code> and <code>stores_grad</code> set to <code>True</code>.</p>
<pre><code class="hljs css language-python">w0 = Tensor(shape=(<span class="hljs-number">2</span>, <span class="hljs-number">3</span>), requires_grad=<span class="hljs-literal">True</span>, stores_grad=<span class="hljs-literal">True</span>)
w0.gaussian(<span class="hljs-number">0.0</span>, <span class="hljs-number">0.1</span>)
b0 = Tensor(shape=(<span class="hljs-number">1</span>, <span class="hljs-number">3</span>), requires_grad=<span class="hljs-literal">True</span>, stores_grad=<span class="hljs-literal">True</span>)
b0.set_value(<span class="hljs-number">0.0</span>)
w1 = Tensor(shape=(<span class="hljs-number">3</span>, <span class="hljs-number">2</span>), requires_grad=<span class="hljs-literal">True</span>, stores_grad=<span class="hljs-literal">True</span>)
w1.gaussian(<span class="hljs-number">0.0</span>, <span class="hljs-number">0.1</span>)
b1 = Tensor(shape=(<span class="hljs-number">1</span>, <span class="hljs-number">2</span>), requires_grad=<span class="hljs-literal">True</span>, stores_grad=<span class="hljs-literal">True</span>)
b1.set_value(<span class="hljs-number">0.0</span>)
</code></pre>
<h4><a class="anchor" aria-hidden="true" id="训练"></a><a href="#训练" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>训练</h4>
<pre><code class="hljs css language-python">inputs = Tensor(data=data) <span class="hljs-comment"># data matrix</span>
target = Tensor(data=label) <span class="hljs-comment"># label vector</span>
autograd.training = <span class="hljs-literal">True</span> <span class="hljs-comment"># for training</span>
sgd = opt.SGD(<span class="hljs-number">0.05</span>) <span class="hljs-comment"># optimizer</span>
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(<span class="hljs-number">10</span>):
x = autograd.matmul(inputs, w0) <span class="hljs-comment"># matrix multiplication</span>
x = autograd.add_bias(x, b0) <span class="hljs-comment"># add the bias vector</span>
x = autograd.relu(x) <span class="hljs-comment"># ReLU activation operation</span>
x = autograd.matmul(x, w1)
x = autograd.add_bias(x, b1)
loss = autograd.softmax_cross_entropy(x, target)
<span class="hljs-keyword">for</span> p, g <span class="hljs-keyword">in</span> autograd.backward(loss):
sgd.update(p, g)
</code></pre>
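<p>After training, the same operations can be reused for inference. A minimal sketch: switch off the training flag so that no computation graph is recorded, and run the forward pass only.</p>
<pre><code class="language-python">autograd.training = False  # inference mode: no graph is recorded

x = autograd.matmul(inputs, w0)
x = autograd.add_bias(x, b0)
x = autograd.relu(x)
x = autograd.matmul(x, w1)
scores = autograd.add_bias(x, b1)  # class scores for each input row
</code></pre>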
<h3><a class="anchor" aria-hidden="true" id="使用operation和layer"></a><a href="#使用operation和layer" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>使用Operation和Layer</h3>
<p>下面的<a href="https://github.com/apache/singa/blob/master/examples/autograd/mnist_cnn.py">例子</a>使用autograd模块提供的层实现了一个CNN模型。</p>
<h4><a class="anchor" aria-hidden="true" id="创建层"></a><a href="#创建层" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>创建层</h4>
<pre><code class="hljs css language-python">conv1 = autograd.Conv2d(<span class="hljs-number">1</span>, <span class="hljs-number">32</span>, <span class="hljs-number">3</span>, padding=<span class="hljs-number">1</span>, bias=<span class="hljs-literal">False</span>)
bn1 = autograd.BatchNorm2d(<span class="hljs-number">32</span>)
pooling1 = autograd.MaxPool2d(<span class="hljs-number">3</span>, <span class="hljs-number">1</span>, padding=<span class="hljs-number">1</span>)
conv21 = autograd.Conv2d(<span class="hljs-number">32</span>, <span class="hljs-number">16</span>, <span class="hljs-number">3</span>, padding=<span class="hljs-number">1</span>)
conv22 = autograd.Conv2d(<span class="hljs-number">32</span>, <span class="hljs-number">16</span>, <span class="hljs-number">3</span>, padding=<span class="hljs-number">1</span>)
bn2 = autograd.BatchNorm2d(<span class="hljs-number">32</span>)
linear = autograd.Linear(<span class="hljs-number">32</span> * <span class="hljs-number">28</span> * <span class="hljs-number">28</span>, <span class="hljs-number">10</span>)
pooling2 = autograd.AvgPool2d(<span class="hljs-number">3</span>, <span class="hljs-number">1</span>, padding=<span class="hljs-number">1</span>)
</code></pre>
<h4><a class="anchor" aria-hidden="true" id="定义正向传播函数"></a><a href="#定义正向传播函数" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>定义正向传播函数</h4>
<p>The operations in the forward pass are recorded automatically for backward propagation.</p>
<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(x, t)</span>:</span>
<span class="hljs-comment"># x is the input data (a batch of images)</span>
<span class="hljs-comment"># t is the label vector (a batch of integers)</span>
y = conv1(x) <span class="hljs-comment"># Conv layer</span>
y = autograd.relu(y) <span class="hljs-comment"># ReLU operation</span>
y = bn1(y) <span class="hljs-comment"># BN layer</span>
y = pooling1(y) <span class="hljs-comment"># Pooling Layer</span>
<span class="hljs-comment"># two parallel convolution layers</span>
y1 = conv21(y)
y2 = conv22(y)
y = autograd.cat((y1, y2), <span class="hljs-number">1</span>) <span class="hljs-comment"># cat operation</span>
y = autograd.relu(y) <span class="hljs-comment"># ReLU operation</span>
y = bn2(y)
y = pooling2(y)
y = autograd.flatten(y) <span class="hljs-comment"># flatten operation</span>
y = linear(y) <span class="hljs-comment"># Linear layer</span>
loss = autograd.softmax_cross_entropy(y, t) <span class="hljs-comment"># operation</span>
<span class="hljs-keyword">return</span> loss, y
</code></pre>
<h4><a class="anchor" aria-hidden="true" id="训练-1"></a><a href="#训练-1" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>训练</h4>
<pre><code class="hljs css language-python">autograd.training = <span class="hljs-literal">True</span>
<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> range(epochs):
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(batch_number):
inputs = tensor.Tensor(device=dev, data=x_train[
i * batch_sz:(<span class="hljs-number">1</span> + i) * batch_sz], stores_grad=<span class="hljs-literal">False</span>)
targets = tensor.Tensor(device=dev, data=y_train[
i * batch_sz:(<span class="hljs-number">1</span> + i) * batch_sz], requires_grad=<span class="hljs-literal">False</span>, stores_grad=<span class="hljs-literal">False</span>)
loss, y = forward(inputs, targets) <span class="hljs-comment"># forward the net</span>
<span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss): <span class="hljs-comment"># auto backward</span>
sgd.update(p, gp)
</code></pre>
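<p>The snippet above omits some setup. A sketch of the assumed context is given below; <code>x_train</code> and <code>y_train</code> stand in for your own numpy arrays, and the device call assumes a GPU build (use <code>device.get_default_device()</code> for CPU).</p>
<pre><code class="language-python">from singa import device, opt, tensor

dev = device.create_cuda_gpu()   # assumption: a CUDA-enabled build
sgd = opt.SGD(0.05)              # the optimizer used in sgd.update()

epochs, batch_sz = 10, 64
x_train, y_train = ...           # load your dataset as numpy arrays
batch_number = x_train.shape[0] // batch_sz
</code></pre>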
<h3><a class="anchor" aria-hidden="true" id="using-the-model-api"></a><a href="#using-the-model-api" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Using the Model API</h3>
<p>下面的<a href="https://github.com/apache/singa/blob/master/examples/cnn/model/cnn.py">样例</a>使用<a href="./graph">Model API</a>实现了一个CNN模型。.</p>
<h4><a class="anchor" aria-hidden="true" id="定义model的子类"></a><a href="#定义model的子类" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>定义Model的子类</h4>
<p>Define the model class as a subclass of Model. Only in this way will all the operations used during the training phase form a computational graph that can be analyzed; the operations in the graph are then scheduled and executed efficiently. Layers can also be included in the model class.</p>
<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MLP</span><span class="hljs-params">(model.Model)</span>:</span> <span class="hljs-comment"># the model is a subclass of Model</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, data_size=<span class="hljs-number">10</span>, perceptron_size=<span class="hljs-number">100</span>, num_classes=<span class="hljs-number">10</span>)</span>:</span>
super(MLP, self).__init__()
<span class="hljs-comment"># init the operators, layers and other objects</span>
self.relu = layer.ReLU()
self.linear1 = layer.Linear(perceptron_size)
self.linear2 = layer.Linear(num_classes)
self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, inputs)</span>:</span> <span class="hljs-comment"># define the forward function</span>
y = self.linear1(inputs)
y = self.relu(y)
y = self.linear2(y)
<span class="hljs-keyword">return</span> y
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(self, x, y)</span>:</span>
out = self.forward(x)
loss = self.softmax_cross_entropy(out, y)
self.optimizer(loss)
<span class="hljs-keyword">return</span> out, loss
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">set_optimizer</span><span class="hljs-params">(self, optimizer)</span>:</span> <span class="hljs-comment"># attach an optimizer</span>
self.optimizer = optimizer
</code></pre>
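<p>The class above assumes that the relevant SINGA modules have been imported; a sketch of the imports used by this and the following snippets:</p>
<pre><code class="language-python">from singa import layer    # layer.ReLU, layer.Linear, layer.SoftMaxCrossEntropy
from singa import model    # model.Model base class
from singa import opt      # opt.SGD optimizer
from singa import tensor   # tensor.Tensor placeholders
</code></pre>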
<h4><a class="anchor" aria-hidden="true" id="训练-2"></a><a href="#训练-2" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>训练</h4>
<pre><code class="hljs css language-python"><span class="hljs-comment"># create a model instance</span>
model = MLP()
<span class="hljs-comment"># initialize optimizer and attach it to the model</span>
sgd = opt.SGD(lr=<span class="hljs-number">0.005</span>, momentum=<span class="hljs-number">0.9</span>, weight_decay=<span class="hljs-number">1e-5</span>)
model.set_optimizer(sgd)
<span class="hljs-comment"># input and target placeholders for the model</span>
tx = tensor.Tensor((batch_size, <span class="hljs-number">1</span>, IMG_SIZE, IMG_SIZE), dev, tensor.float32)
ty = tensor.Tensor((batch_size, num_classes), dev, tensor.int32)
<span class="hljs-comment"># compile the model before training</span>
model.compile([tx], is_train=<span class="hljs-literal">True</span>, use_graph=<span class="hljs-literal">True</span>, sequential=<span class="hljs-literal">False</span>)
<span class="hljs-comment"># train the model iteratively</span>
<span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(num_train_batch):
<span class="hljs-comment"># generate the next mini-batch</span>
x, y = ...
<span class="hljs-comment"># Copy the data into input tensors</span>
tx.copy_from_numpy(x)
ty.copy_from_numpy(y)
<span class="hljs-comment"># Training with one batch</span>
out, loss = model(tx, ty)
</code></pre>
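<p>After training, the same instance can be used for prediction. The sketch below assumes the Model API provides a <code>model.eval()</code> switch (as in the Model guide); in evaluation mode, calling the model with the input only runs <code>forward()</code> and returns the raw scores without updating parameters.</p>
<pre><code class="language-python"># Hedged sketch: x_eval is a numpy batch of the same shape as tx (assumption).
model.eval()                 # assumption: switch to evaluation mode
tx.copy_from_numpy(x_eval)
out = model(tx)              # runs forward() only; no loss, no update
</code></pre>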
<h4><a class="anchor" aria-hidden="true" id="保存模型checkpoint"></a><a href="#保存模型checkpoint" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>保存模型checkpoint</h4>
<pre><code class="hljs css language-python"><span class="hljs-comment"># define the path to save the checkpoint</span>
checkpointpath=<span class="hljs-string">"checkpoint.zip"</span>
<span class="hljs-comment"># save a checkpoint</span>
model.save_states(fpath=checkpointpath)
</code></pre>
<h4><a class="anchor" aria-hidden="true" id="加载模型checkpoint"></a><a href="#加载模型checkpoint" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>加载模型checkpoint</h4>
<pre><code class="hljs css language-python"><span class="hljs-comment"># define the path to load the checkpoint</span>
checkpointpath=<span class="hljs-string">"checkpoint.zip"</span>
<span class="hljs-comment"># load a checkpoint</span>
<span class="hljs-keyword">import</span> os
<span class="hljs-keyword">if</span> os.path.exists(checkpointpath):
model.load_states(fpath=checkpointpath)
</code></pre>
<h3><a class="anchor" aria-hidden="true" id="python-api"></a><a href="#python-api" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Python API</h3>
<p>Refer <a href="https://singa.readthedocs.io/en/latest/autograd.html#module-singa.autograd">here</a> for more details of the Python API.</p>