blob: 304ce0e743abcb3c041bb0a5e1e93a4708ff75c1 [file] [log] [blame]
<!DOCTYPE html><html lang="en"><head><meta charSet="utf-8"/><meta http-equiv="X-UA-Compatible" content="IE=edge"/><title>Optimizer · Apache SINGA</title><meta name="viewport" content="width=device-width"/><meta name="generator" content="Docusaurus"/><meta name="description" content="&lt;!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --&gt;"/><meta name="docsearch:version" content="next"/><meta name="docsearch:language" content="en"/><meta property="og:title" content="Optimizer · Apache SINGA"/><meta property="og:type" content="website"/><meta property="og:url" content="https://singa.apache.org/"/><meta property="og:description" content="&lt;!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --&gt;"/><meta property="og:image" content="https://singa.apache.org/img/singa_twitter_banner.jpeg"/><meta name="twitter:card" content="summary"/><meta name="twitter:image" content="https://singa.apache.org/img/singa_twitter_banner.jpeg"/><link rel="shortcut icon" href="/img/favicon.ico"/><link rel="stylesheet" href="https://cdn.jsdelivr.net/docsearch.js/1/docsearch.min.css"/><link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/atom-one-dark.min.css"/><link rel="alternate" type="application/atom+xml" href="https://singa.apache.org/blog/atom.xml" title="Apache SINGA Blog ATOM Feed"/><link rel="alternate" type="application/rss+xml" href="https://singa.apache.org/blog/feed.xml" title="Apache SINGA Blog RSS Feed"/><script type="text/javascript" src="https://buttons.github.io/buttons.js"></script><script src="https://unpkg.com/vanilla-back-to-top@7.1.14/dist/vanilla-back-to-top.min.js"></script><script>
// Install the floating "back to top" button once the DOM has been parsed.
document.addEventListener('DOMContentLoaded', function() {
  var backToTopOptions = { zIndex: 100 };
  addBackToTop(backToTopOptions);
});
</script><script src="/js/scrollSpy.js"></script><link rel="stylesheet" href="/css/main.css"/><script src="/js/codetabs.js"></script></head><body class="sideNavVisible separateOnPageNav"><div class="fixedHeaderContainer"><div class="headerWrapper wrapper"><header><a href="/"><img class="logo" src="/img/singa.png" alt="Apache SINGA"/></a><a href="/versions"><h3>next</h3></a><div class="navigationWrapper navigationSlider"><nav class="slidingNav"><ul class="nav-site nav-site-internal"><li class="siteNavGroupActive"><a href="/docs/next/installation" target="_self">Docs</a></li><li class=""><a href="/docs/next/source-repository" target="_self">Community</a></li><li class=""><a href="/blog/" target="_self">News</a></li><li class=""><a href="https://apache-singa.readthedocs.io/en/latest/" target="_self">API</a></li><li class="navSearchWrapper reactNavSearchWrapper"><input type="text" id="search_input_react" placeholder="Search" title="Search"/></li><li class=""><a href="https://github.com/apache/singa" target="_self">GitHub</a></li></ul></nav></div></header></div></div><div class="navPusher"><div class="docMainWrapper wrapper"><div class="docsNavContainer" id="docsNav"><nav class="toc"><div class="toggleNav"><section class="navWrapper wrapper"><div class="navBreadcrumb wrapper"><div class="navToggle" id="navToggler"><div class="hamburger-menu"><div class="line1"></div><div class="line2"></div><div class="line3"></div></div></div><h2><i></i><span>Guides</span></h2><div class="tocToggler" id="tocToggler"><i class="icon-toc"></i></div></div><div class="navGroups"><div class="navGroup"><h3 class="navGroupCategoryTitle">Getting Started</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/next/installation">Installation</a></li><li class="navListItem"><a class="navItem" href="/docs/next/software-stack">Software Stack</a></li><li class="navListItem"><a class="navItem" href="/docs/next/examples">Examples</a></li></ul></div><div class="navGroup"><h3 
class="navGroupCategoryTitle">Guides</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/next/device">Device</a></li><li class="navListItem"><a class="navItem" href="/docs/next/tensor">Tensor</a></li><li class="navListItem"><a class="navItem" href="/docs/next/autograd">Autograd</a></li><li class="navListItem navListItemActive"><a class="navItem" href="/docs/next/optimizer">Optimizer</a></li><li class="navListItem"><a class="navItem" href="/docs/next/graph">Model</a></li><li class="navListItem"><a class="navItem" href="/docs/next/onnx">ONNX</a></li><li class="navListItem"><a class="navItem" href="/docs/next/dist-train">Distributed Training</a></li><li class="navListItem"><a class="navItem" href="/docs/next/time-profiling">Time Profiling</a></li></ul></div><div class="navGroup"><h3 class="navGroupCategoryTitle">Development</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/next/download-singa">Download SINGA</a></li><li class="navListItem"><a class="navItem" href="/docs/next/build">Build SINGA from Source</a></li><li class="navListItem"><a class="navItem" href="/docs/next/contribute-code">How to Contribute Code</a></li><li class="navListItem"><a class="navItem" href="/docs/next/contribute-docs">How to Contribute to Documentation</a></li><li class="navListItem"><a class="navItem" href="/docs/next/how-to-release">How to Prepare a Release</a></li><li class="navListItem"><a class="navItem" href="/docs/next/git-workflow">Git Workflow</a></li></ul></div></div></section></div><script>
// Sidebar accordion. For every element with class "collapsible":
//  - the first one whose following sibling contains the active nav item gets
//    its "hide" / "rotate" classes toggled up front;
//  - every one gets a click handler that toggles those same classes.
var coll = document.getElementsByClassName('collapsible');
var checkActiveCategory = true;
var i = 0;
while (i < coll.length) {
  var links = coll[i].nextElementSibling.getElementsByTagName('*');
  // Scan descendants only until the active category has been found once.
  for (var j = 0; checkActiveCategory && j < links.length; j++) {
    if (!links[j].classList.contains('navListItemActive')) {
      continue;
    }
    coll[i].nextElementSibling.classList.toggle('hide');
    coll[i].childNodes[1].classList.toggle('rotate');
    checkActiveCategory = false;
    break;
  }
  coll[i].addEventListener('click', function() {
    // `this` is the clicked header; childNodes[1] is used as its arrow node.
    this.childNodes[1].classList.toggle('rotate');
    this.nextElementSibling.classList.toggle('hide');
  });
  i++;
}
// Once the DOM is parsed, wire up the two navigation togglers and make the
// on-page table of contents drop the 'tocActive' class from <body> after one
// of its links is clicked.
document.addEventListener('DOMContentLoaded', function() {
  createToggler('#navToggler', '#docsNav', 'docsSliderActive');
  createToggler('#tocToggler', 'body', 'tocActive');

  var headings = document.querySelector('.toc-headings');
  headings && headings.addEventListener('click', function(event) {
    // Walk up from the clicked node towards the list container. The extra
    // null check stops the walk if the node has been detached from the list,
    // instead of throwing on `null.tagName`.
    var el = event.target;
    while (el && el !== headings) {
      if (el.tagName === 'A') {
        document.body.classList.remove('tocActive');
        break;
      } else {
        el = el.parentNode;
      }
    }
  }, false);

  /**
   * Toggle `className` on the first element matching `targetSelector`
   * whenever the first element matching `togglerSelector` is clicked.
   *
   * Does nothing when either selector matches no element, so a page that
   * lacks the toggler or the target never gets a handler that would throw.
   *
   * @param {string} togglerSelector CSS selector of the clickable element.
   * @param {string} targetSelector CSS selector of the element to toggle.
   * @param {string} className Class name to toggle on the target.
   */
  function createToggler(togglerSelector, targetSelector, className) {
    var toggler = document.querySelector(togglerSelector);
    var target = document.querySelector(targetSelector);
    if (!toggler || !target) {
      return;
    }
    toggler.onclick = function(event) {
      event.preventDefault();
      target.classList.toggle(className);
    };
  }
});
</script></nav></div><div class="container mainContainer docsContainer"><div class="wrapper"><div class="post"><header class="postHeader"><a class="edit-page-link button" href="https://github.com/apache/singa-doc/blob/master/docs-site/docs/optimizer.md" target="_blank" rel="noreferrer noopener">Edit</a><h1 id="__docusaurus" class="postHeaderTitle">Optimizer</h1></header><article><div><span><!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -->
<p>SINGA supports various popular optimizers, including stochastic gradient descent
with momentum, Adam, RMSProp, and AdaGrad. Each optimizer can be combined with
a decay scheduler that adjusts the learning rate to be applied in
different epochs. The optimizers and the decay schedulers are included in
<code>singa/opt.py</code>.</p>
<h2><a class="anchor" aria-hidden="true" id="create-an-optimizer"></a><a href="#create-an-optimizer" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Create an optimizer</h2>
<ol>
<li>SGD with momentum</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-comment"># define hyperparameter learning rate</span>
lr = <span class="hljs-number">0.001</span>
<span class="hljs-comment"># define hyperparameter momentum</span>
momentum = <span class="hljs-number">0.9</span>
<span class="hljs-comment"># define hyperparameter weight decay</span>
weight_decay = <span class="hljs-number">0.0001</span>
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
sgd = opt.SGD(lr=lr, momentum=momentum, weight_decay=weight_decay)
</code></pre>
<ol start="2">
<li>RMSProp</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-comment"># define hyperparameter learning rate</span>
lr = <span class="hljs-number">0.001</span>
<span class="hljs-comment"># define hyperparameter rho</span>
rho = <span class="hljs-number">0.9</span>
<span class="hljs-comment"># define hyperparameter epsilon</span>
epsilon = <span class="hljs-number">1e-8</span>
<span class="hljs-comment"># define hyperparameter weight decay</span>
weight_decay = <span class="hljs-number">0.0001</span>
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
sgd = opt.RMSProp(lr=lr, rho=rho, epsilon=epsilon, weight_decay=weight_decay)
</code></pre>
<ol start="3">
<li>AdaGrad</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-comment"># define hyperparameter learning rate</span>
lr = <span class="hljs-number">0.001</span>
<span class="hljs-comment"># define hyperparameter epsilon</span>
epsilon = <span class="hljs-number">1e-8</span>
<span class="hljs-comment"># define hyperparameter weight decay</span>
weight_decay = <span class="hljs-number">0.0001</span>
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
sgd = opt.AdaGrad(lr=lr, epsilon=epsilon, weight_decay=weight_decay)
</code></pre>
<ol start="4">
<li>Adam</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-comment"># define hyperparameter learning rate</span>
lr = <span class="hljs-number">0.001</span>
<span class="hljs-comment"># define hyperparameter beta 1</span>
beta_1= <span class="hljs-number">0.9</span>
<span class="hljs-comment"># define hyperparameter beta 2</span>
beta_2 = <span class="hljs-number">0.999</span>
<span class="hljs-comment"># define hyperparameter epsilon</span>
epsilon = <span class="hljs-number">1e-8</span>
<span class="hljs-comment"># define hyperparameter weight decay</span>
weight_decay = <span class="hljs-number">0.0001</span>
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
sgd = opt.Adam(lr=lr, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon, weight_decay=weight_decay)
</code></pre>
<h2><a class="anchor" aria-hidden="true" id="create-a-decay-scheduler"></a><a href="#create-a-decay-scheduler" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Create a Decay Scheduler</h2>
<pre><code class="hljs css language-python"><span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
<span class="hljs-comment"># define initial learning rate</span>
lr_init = <span class="hljs-number">0.001</span>
<span class="hljs-comment"># define the rate of decay in the decay scheduler</span>
decay_rate = <span class="hljs-number">0.95</span>
<span class="hljs-comment"># define whether the learning rate schedule is a staircase shape</span>
staircase=<span class="hljs-literal">True</span>
<span class="hljs-comment"># define the decay step of the decay scheduler (in this example the lr is decreased at every 2 steps)</span>
decay_steps = <span class="hljs-number">2</span>
<span class="hljs-comment"># create the decay scheduler, the schedule of lr becomes lr_init * (decay_rate ^ (step // decay_steps) )</span>
lr = opt.ExponentialDecay(lr_init, decay_steps, decay_rate, staircase)
<span class="hljs-comment"># Use the lr to create an optimizer</span>
sgd = opt.SGD(lr=lr, momentum=<span class="hljs-number">0.9</span>, weight_decay=<span class="hljs-number">0.0001</span>)
</code></pre>
<h2><a class="anchor" aria-hidden="true" id="use-the-optimizer-in-model-api"></a><a href="#use-the-optimizer-in-model-api" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Use the optimizer in Model API</h2>
<p>When we create the model, we need to attach the optimizer to the model.</p>
<pre><code class="hljs css language-python"><span class="hljs-comment"># create a CNN using the Model API</span>
model = CNN()
<span class="hljs-comment"># initialize optimizer and attach it to the model</span>
sgd = opt.SGD(lr=<span class="hljs-number">0.005</span>, momentum=<span class="hljs-number">0.9</span>, weight_decay=<span class="hljs-number">1e-5</span>)
model.set_optimizer(sgd)
</code></pre>
<p>Then, when we call the model, it runs the <code>train_one_batch</code> method that utilizes
the optimizer.</p>
<p>Hence, an example of an iterative loop to optimize the model is:</p>
<pre><code class="hljs css language-python"><span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(num_train_batch):
    <span class="hljs-comment"># generate the next mini-batch</span>
    x, y = ...
    <span class="hljs-comment"># Copy the data into input tensors</span>
    tx.copy_from_numpy(x)
    ty.copy_from_numpy(y)
    <span class="hljs-comment"># Training with one batch</span>
    out, loss = model(tx, ty)
</code></pre>
</span></div></article></div><div class="docLastUpdate"><em>Last updated on 4/7/2020</em></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/next/autograd"><span class="arrow-prev"></span><span>Autograd</span></a><a class="docs-next button" href="/docs/next/graph"><span>Model</span><span class="arrow-next"></span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#create-an-optimizer">Create an optimizer</a></li><li><a href="#create-a-decay-scheduler">Create a Decay Scheduler</a></li><li><a href="#use-the-optimizer-in-model-api">Use the optimizer in Model API</a></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a href="/en/https://apache-singa.readthedocs.io/en/latest/">API Reference</a><a href="/docs/examples">Examples</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/blog">SINGA News</a><a href="https://github.com/apache/singa">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" 
target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2020
The Apache Software Foundation. All rights reserved.
Apache SINGA, Apache, the Apache feather logo, and
the Apache SINGA project logos are trademarks of The
Apache Software Foundation. All other marks mentioned
may be trademarks or registered trademarks of their
respective owners.</section></div></footer></div><script type="text/javascript" src="https://cdn.jsdelivr.net/docsearch.js/1/docsearch.min.js"></script><script>window.twttr=(function(d,s, id){var js,fjs=d.getElementsByTagName(s)[0],t=window.twttr||{};if(d.getElementById(id))return t;js=d.createElement(s);js.id=id;js.src='https://platform.twitter.com/widgets.js';fjs.parentNode.insertBefore(js, fjs);t._e = [];t.ready = function(f) {t._e.push(f);};return t;}(document, 'script', 'twitter-wjs'));</script><script>
// Keyboard shortcut: releasing "/" while the event target is the page body
// moves focus to the search input, if that input exists.
document.addEventListener('keyup', function(evt) {
  // keyCode 191 is '/' (slash)
  const slashOnBody = evt.target === document.body && evt.keyCode === 191;
  if (!slashOnBody) {
    return;
  }
  const searchBox = document.getElementById('search_input_react');
  if (searchBox) {
    searchBox.focus();
  }
});
</script><script>
// Attach DocSearch to the header search input. The facet filters passed in
// algoliaOptions restrict results to the listed language and version.
var search = docsearch({
  apiKey: '45202133606c0b5fa6d21cddc4725dd8',
  indexName: 'apache_singa',
  inputSelector: '#search_input_react',
  algoliaOptions: {
    facetFilters: [
      'language:en',
      'version:3.0.0'
    ]
  }
});
</script></body></html>