blob: 73288056e9afec34ea3b8a23c038c99f332aa68f [file] [log] [blame]
<!DOCTYPE html><html lang="en"><head><meta charSet="utf-8"/><meta http-equiv="X-UA-Compatible" content="IE=edge"/><title>Optimizer · Apache SINGA</title><meta name="viewport" content="width=device-width"/><meta name="generator" content="Docusaurus"/><meta name="description" content="&lt;!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --&gt;"/><meta name="docsearch:version" content="4.0.0_Chinese"/><meta name="docsearch:language" content="en"/><meta property="og:title" content="Optimizer · Apache SINGA"/><meta property="og:type" content="website"/><meta property="og:url" content="https://singa.apache.org/"/><meta property="og:description" content="&lt;!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --&gt;"/><meta property="og:image" content="https://singa.apache.org/img/singa_twitter_banner.jpeg"/><meta name="twitter:card" content="summary"/><meta name="twitter:image" content="https://singa.apache.org/img/singa_twitter_banner.jpeg"/><link rel="shortcut icon" href="/img/favicon.ico"/><link rel="stylesheet" href="https://cdn.jsdelivr.net/docsearch.js/1/docsearch.min.css"/><link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/atom-one-dark.min.css"/><link rel="alternate" type="application/atom+xml" href="https://singa.apache.org/blog/atom.xml" title="Apache SINGA Blog ATOM Feed"/><link rel="alternate" type="application/rss+xml" href="https://singa.apache.org/blog/feed.xml" title="Apache SINGA Blog RSS Feed"/><link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Source+Sans+Pro:400,400i,700"/><link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Baloo+Paaji+2&amp;family=Source+Sans+Pro:wght@200;300&amp;display=swap"/><script type="text/javascript" src="https://buttons.github.io/buttons.js"></script><script src="https://unpkg.com/vanilla-back-to-top@7.1.14/dist/vanilla-back-to-top.min.js"></script><script>
// After the DOM has been parsed, create the "back to top" button.
// addBackToTop is not defined in this inline script; presumably it comes from
// the vanilla-back-to-top script included just before it.
document.addEventListener('DOMContentLoaded', function installBackToTop() {
  var backToTopOptions = { zIndex: 100 };
  addBackToTop(backToTopOptions);
});
</script><script src="/js/scrollSpy.js"></script><link rel="stylesheet" href="/css/main.css"/><script src="/js/codetabs.js"></script></head><body class="sideNavVisible separateOnPageNav"><div class="fixedHeaderContainer"><div class="headerWrapper wrapper"><header><a href="/"><img class="logo" src="/img/singa.png" alt="Apache SINGA"/></a><a href="/versions"><h3>4.0.0_Chinese</h3></a><div class="navigationWrapper navigationSlider"><nav class="slidingNav"><ul class="nav-site nav-site-internal"><li class="siteNavGroupActive"><a href="/docs/4.0.0_Chinese/installation" target="_self">Docs</a></li><li class=""><a href="/docs/4.0.0_Chinese/source-repository" target="_self">Community</a></li><li class=""><a href="/blog/" target="_self">News</a></li><li class=""><a href="https://apache-singa.readthedocs.io/en/latest/" target="_self">API</a></li><li class="navSearchWrapper reactNavSearchWrapper"><input type="text" id="search_input_react" placeholder="Search" title="Search"/></li><li class=""><a href="https://github.com/apache/singa" target="_self">GitHub</a></li></ul></nav></div></header></div></div><div class="navPusher"><div class="docMainWrapper wrapper"><div class="docsNavContainer" id="docsNav"><nav class="toc"><div class="toggleNav"><section class="navWrapper wrapper"><div class="navBreadcrumb wrapper"><div class="navToggle" id="navToggler"><div class="hamburger-menu"><div class="line1"></div><div class="line2"></div><div class="line3"></div></div></div><h2><i></i><span>Guides</span></h2><div class="tocToggler" id="tocToggler"><i class="icon-toc"></i></div></div><div class="navGroups"><div class="navGroup"><h3 class="navGroupCategoryTitle">Getting Started</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/installation">Installation</a></li><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/software-stack">Software Stack</a></li><li class="navListItem"><a class="navItem" 
href="/docs/4.0.0_Chinese/examples">Examples</a></li></ul></div><div class="navGroup"><h3 class="navGroupCategoryTitle">Guides</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/device">Device</a></li><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/tensor">Tensor</a></li><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/autograd">Autograd</a></li><li class="navListItem navListItemActive"><a class="navItem" href="/docs/4.0.0_Chinese/optimizer">Optimizer</a></li><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/graph">Model</a></li><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/onnx">ONNX</a></li><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/dist-train">Distributed Training</a></li><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/time-profiling">Time Profiling</a></li><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/half-precision">Half Precision</a></li></ul></div><div class="navGroup"><h3 class="navGroupCategoryTitle">Development</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/downloads">Download SINGA</a></li><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/build">Build SINGA from Source</a></li><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/contribute-code">How to Contribute Code</a></li><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/contribute-docs">How to Contribute to Documentation</a></li><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/how-to-release">How to Prepare a Release</a></li><li class="navListItem"><a class="navItem" href="/docs/4.0.0_Chinese/git-workflow">Git Workflow</a></li></ul></div></div></section></div><script>
// Sidebar categories: each element with class "collapsible" is treated as a
// header whose next element sibling is the panel it shows or hides.
var coll = document.getElementsByClassName('collapsible');
// True until one panel containing the active nav item has been toggled, so at
// most one category is toggled while the page loads.
var checkActiveCategory = true;

Array.prototype.forEach.call(coll, function (header) {
  var panel = header.nextElementSibling;
  // Looked up unconditionally, exactly like the per-iteration lookup it replaces.
  var descendants = panel.getElementsByTagName('*');

  if (checkActiveCategory) {
    var holdsActiveItem = Array.prototype.some.call(descendants, function (node) {
      return node.classList.contains('navListItemActive');
    });
    if (holdsActiveItem) {
      panel.classList.toggle('hide');
      // childNodes[1] is the node that receives the "rotate" class
      // (presumably the arrow icon of the header).
      header.childNodes[1].classList.toggle('rotate');
      checkActiveCategory = false;
    }
  }

  // Clicking a header flips both its arrow and the visibility of its panel.
  header.addEventListener('click', function () {
    this.childNodes[1].classList.toggle('rotate');
    this.nextElementSibling.classList.toggle('hide');
  });
});
// Once the DOM is ready, wire up the mobile navigation togglers and make a
// click on any link inside the on-page table of contents close that panel.
document.addEventListener('DOMContentLoaded', function () {
  // Makes a click on the element matching `togglerSelector` toggle `className`
  // on the element matching `targetSelector`. Does nothing when the toggler
  // element is not present.
  var createToggler = function (togglerSelector, targetSelector, className) {
    var toggler = document.querySelector(togglerSelector);
    var target = document.querySelector(targetSelector);
    if (toggler) {
      toggler.onclick = function (event) {
        event.preventDefault();
        target.classList.toggle(className);
      };
    }
  };

  createToggler('#navToggler', '#docsNav', 'docsSliderActive');
  createToggler('#tocToggler', 'body', 'tocActive');

  var headings = document.querySelector('.toc-headings');
  if (!headings) {
    return;
  }
  headings.addEventListener('click', function (event) {
    // Walk from the clicked node up towards the list; if an anchor is found
    // on the way, remove the "tocActive" class from the body.
    for (var el = event.target; el !== headings; el = el.parentNode) {
      if (el.tagName === 'A') {
        document.body.classList.remove('tocActive');
        break;
      }
    }
  }, false);
});
</script></nav></div><div class="container mainContainer docsContainer"><div class="wrapper"><div class="post"><header class="postHeader"><a class="edit-page-link button" href="https://github.com/apache/singa-doc/blob/master/docs-site/docs/optimizer.md" target="_blank" rel="noreferrer noopener">Edit</a><h1 id="__docusaurus" class="postHeaderTitle">Optimizer</h1></header><article><div><span><!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -->
<p>SINGA支持各种流行的优化器,包括动量随机梯度下降、Adam、RMSProp和AdaGrad等。对于每一种优化器,它都支持使用衰减调度器来安排不同时间段的学习率。优化器和衰减调度器包含在<code>singa/opt.py</code>中。</p>
<h2><a class="anchor" aria-hidden="true" id="创建一个优化器"></a><a href="#创建一个优化器" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>创建一个优化器</h2>
<ol>
<li>带动量的SGD</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-comment"># define hyperparameter learning rate</span>
lr = <span class="hljs-number">0.001</span>
<span class="hljs-comment"># define hyperparameter momentum</span>
momentum = <span class="hljs-number">0.9</span>
<span class="hljs-comment"># define hyperparameter weight decay</span>
weight_decay = <span class="hljs-number">0.0001</span>
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
sgd = opt.SGD(lr=lr, momentum=momentum, weight_decay=weight_decay)
</code></pre>
<ol start="2">
<li>RMSProp</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-comment"># define hyperparameter learning rate</span>
lr = <span class="hljs-number">0.001</span>
<span class="hljs-comment"># define hyperparameter rho</span>
rho = <span class="hljs-number">0.9</span>
<span class="hljs-comment"># define hyperparameter epsilon</span>
epsilon = <span class="hljs-number">1e-8</span>
<span class="hljs-comment"># define hyperparameter weight decay</span>
weight_decay = <span class="hljs-number">0.0001</span>
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
sgd = opt.RMSProp(lr=lr, rho=rho, epsilon=epsilon, weight_decay=weight_decay)
</code></pre>
<ol start="3">
<li>AdaGrad</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-comment"># define hyperparameter learning rate</span>
lr = <span class="hljs-number">0.001</span>
<span class="hljs-comment"># define hyperparameter epsilon</span>
epsilon = <span class="hljs-number">1e-8</span>
<span class="hljs-comment"># define hyperparameter weight decay</span>
weight_decay = <span class="hljs-number">0.0001</span>
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
sgd = opt.AdaGrad(lr=lr, epsilon=epsilon, weight_decay=weight_decay)
</code></pre>
<ol start="4">
<li>Adam</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-comment"># define hyperparameter learning rate</span>
lr = <span class="hljs-number">0.001</span>
<span class="hljs-comment"># define hyperparameter beta 1</span>
beta_1= <span class="hljs-number">0.9</span>
<span class="hljs-comment"># define hyperparameter beta 2</span>
beta_2= <span class="hljs-number">0.999</span>
<span class="hljs-comment"># define hyperparameter epsilon</span>
epsilon = <span class="hljs-number">1e-8</span>
<span class="hljs-comment"># define hyperparameter weight decay</span>
weight_decay = <span class="hljs-number">0.0001</span>
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
sgd = opt.Adam(lr=lr, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon, weight_decay=weight_decay)
</code></pre>
<h2><a class="anchor" aria-hidden="true" id="创建一个衰减调度器"></a><a href="#创建一个衰减调度器" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>创建一个衰减调度器</h2>
<pre><code class="hljs css language-python"><span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
<span class="hljs-comment"># define initial learning rate</span>
lr_init = <span class="hljs-number">0.001</span>
<span class="hljs-comment"># define the rate of decay in the decay scheduler</span>
decay_rate = <span class="hljs-number">0.95</span>
<span class="hljs-comment"># define whether the learning rate schedule is a staircase shape</span>
staircase=<span class="hljs-literal">True</span>
<span class="hljs-comment"># define the decay step of the decay scheduler (in this example the lr is decreased at every 2 steps)</span>
decay_steps = <span class="hljs-number">2</span>
<span class="hljs-comment"># create the decay scheduler, the schedule of lr becomes lr_init * (decay_rate ^ (step // decay_steps) )</span>
lr = opt.ExponentialDecay(lr_init, decay_steps, decay_rate, staircase)
<span class="hljs-comment"># Use the lr to create an optimizer</span>
sgd = opt.SGD(lr=lr, momentum=<span class="hljs-number">0.9</span>, weight_decay=<span class="hljs-number">0.0001</span>)
</code></pre>
<h2><a class="anchor" aria-hidden="true" id="使用模型api中的优化器"></a><a href="#使用模型api中的优化器" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>使用模型API中的优化器</h2>
<p>当我们创建模型时,我们需要将优化器附加到模型上:</p>
<pre><code class="hljs css language-python"><span class="hljs-comment"># create a CNN using the Model API</span>
model = CNN()
<span class="hljs-comment"># initialize optimizer and attach it to the model</span>
sgd = opt.SGD(lr=<span class="hljs-number">0.005</span>, momentum=<span class="hljs-number">0.9</span>, weight_decay=<span class="hljs-number">1e-5</span>)
model.set_optimizer(sgd)
</code></pre>
<p>然后,当我们调用模型时,它会运行利用优化器的 <code>train_one_batch</code> 方法。</p>
<p>因此,一个迭代循环优化模型的例子是:</p>
<pre><code class="hljs css language-python"><span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(num_train_batch):
<span class="hljs-comment"># generate the next mini-batch</span>
x, y = ...
<span class="hljs-comment"># Copy the data into input tensors</span>
tx.copy_from_numpy(x)
ty.copy_from_numpy(y)
<span class="hljs-comment"># Training with one batch</span>
out, loss = model(tx, ty)
</code></pre>
</span></div></article></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/4.0.0_Chinese/autograd"><span class="arrow-prev"></span><span>Autograd</span></a><a class="docs-next button" href="/docs/4.0.0_Chinese/graph"><span>Model</span><span class="arrow-next"></span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#创建一个优化器">创建一个优化器</a></li><li><a href="#创建一个衰减调度器">创建一个衰减调度器</a></li><li><a href="#使用模型api中的优化器">使用模型API中的优化器</a></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a href="https://apache-singa.readthedocs.io/en/latest/">API Reference</a><a href="/docs/examples">Examples</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/blog">SINGA News</a><a href="https://github.com/apache/singa">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer 
noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2023
The Apache Software Foundation. All rights reserved.
Apache SINGA, Apache, the Apache feather logo, and
the Apache SINGA project logos are trademarks of The
Apache Software Foundation. All other marks mentioned
may be trademarks or registered trademarks of their
respective owners.</section></div></footer></div><script type="text/javascript" src="https://cdn.jsdelivr.net/docsearch.js/1/docsearch.min.js"></script><script>window.twttr=(function(d,s, id){var js,fjs=d.getElementsByTagName(s)[0],t=window.twttr||{};if(d.getElementById(id))return t;js=d.createElement(s);js.id=id;js.src='https://platform.twitter.com/widgets.js';fjs.parentNode.insertBefore(js, fjs);t._e = [];t.ready = function(f) {t._e.push(f);};return t;}(document, 'script', 'twitter-wjs'));</script><script>
// Keyboard shortcut: releasing the "/" key while the event target is the page
// body moves focus to the site search input.
document.addEventListener('keyup', function (e) {
  // Ignore key events whose target is a specific element (for example an
  // input the user is typing into).
  if (e.target !== document.body) {
    return;
  }
  // Prefer KeyboardEvent.key, which reports the character actually produced
  // on any keyboard layout. keyCode is deprecated and 191 corresponds to "/"
  // only on some layouts, so it is consulted only when `key` is unavailable.
  const isSlash = typeof e.key === 'string' ? e.key === '/' : e.keyCode === 191;
  if (isSlash) {
    const searchInput = document.getElementById('search_input_react');
    searchInput && searchInput.focus();
  }
});
</script><script>
// Attach DocSearch to the search input in the header; the value returned by
// docsearch() is stored in the global `search` variable.
var search = docsearch({
  inputSelector: '#search_input_react',
  indexName: 'apache_singa',
  apiKey: '45202133606c0b5fa6d21cddc4725dd8',
  algoliaOptions: {
    // NOTE(review): results are filtered to version 3.0.0 although this page
    // is published under 4.0.0_Chinese - confirm that this is intended.
    facetFilters: ['language:en', 'version:3.0.0']
  }
});
</script></body></html>