| <!DOCTYPE html> |
| |
| <html lang="en"> |
| <head> |
| <meta charset="utf-8"/> |
| <meta content="IE=edge" http-equiv="X-UA-Compatible"/> |
| <meta content="width=device-width, initial-scale=1" name="viewport"/> |
| <meta content="Advanced Learning Rate Schedules" property="og:title"> |
| <meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image"> |
| <meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image:secure_url"> |
| <meta content="Advanced Learning Rate Schedules" property="og:description"/> |
| <title>Advanced Learning Rate Schedules — mxnet documentation</title> |
| <link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/> |
| <link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/> |
| <link href="../../_static/basic.css" rel="stylesheet" type="text/css"> |
| <link href="../../_static/pygments.css" rel="stylesheet" type="text/css"> |
| <link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/> |
| <script type="text/javascript"> |
| var DOCUMENTATION_OPTIONS = { |
| URL_ROOT: '../../', |
| VERSION: '', |
| COLLAPSE_INDEX: false, |
| FILE_SUFFIX: '.html', |
| HAS_SOURCE: true, |
| SOURCELINK_SUFFIX: '.txt' |
| }; |
| </script> |
| <script src="https://code.jquery.com/jquery-1.11.1.min.js" type="text/javascript"></script> |
| <script src="../../_static/underscore.js" type="text/javascript"></script> |
| <script src="../../_static/searchtools_custom.js" type="text/javascript"></script> |
| <script src="../../_static/doctools.js" type="text/javascript"></script> |
| <script src="../../_static/selectlang.js" type="text/javascript"></script> |
| <script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script> |
| <script type="text/javascript"> jQuery(function() { Search.loadIndex("/versions/1.5.0/searchindex.js"); Search.init();}); </script> |
| <script> |
| (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ |
| (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new |
| Date();a=s.createElement(o), |
| m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) |
| })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); |
| |
| ga('create', 'UA-96378503-1', 'auto'); |
| ga('send', 'pageview'); |
| |
| </script> |
| <link href="../../genindex.html" rel="index" title="Index"> |
| <link href="../../search.html" rel="search" title="Search"/> |
| <link href="index.html" rel="up" title="Tutorials"/> |
| <link href="logistic_regression_explained.html" rel="next" title="Logistic regression using Gluon API explained"/> |
| <link href="learning_rate_schedules.html" rel="prev" title="Learning Rate Schedules"/> |
| <link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/> |
| </link></link></link></meta></meta></meta></head> |
| <body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document"> |
| <div class="content-block"><div class="navbar navbar-fixed-top"> |
| <div class="container" id="navContainer"> |
| <div class="innder" id="header-inner"> |
| <h1 id="logo-wrap"> |
| <a href="../../" id="logo"><img src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a> |
| </h1> |
| <nav class="nav-bar" id="main-nav"> |
| <a class="main-nav-link" href="/versions/1.5.0/install/index.html">Install</a> |
| <span id="dropdown-menu-position-anchor"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu"> |
| <li><a class="main-nav-link" href="/versions/1.5.0/tutorials/gluon/gluon.html">About</a></li> |
| <li><a class="main-nav-link" href="https://www.d2l.ai/">Dive into Deep Learning</a></li> |
| <li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li> |
| <li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li> |
| </ul> |
| </span> |
| <span id="dropdown-menu-position-anchor"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu"> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/python/index.html">Python</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/c++/index.html">C++</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/clojure/index.html">Clojure</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/java/index.html">Java</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/julia/index.html">Julia</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/perl/index.html">Perl</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/r/index.html">R</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/scala/index.html">Scala</a></li> |
| </ul> |
| </span> |
| <span id="dropdown-menu-position-anchor-docs"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs"> |
| <li><a class="main-nav-link" href="/versions/1.5.0/faq/index.html">FAQ</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/tutorials/index.html">Tutorials</a></li> |
| <li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.5.0/example">Examples</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/architecture/index.html">Architecture</a></li> |
| <li><a class="main-nav-link" href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/model_zoo/index.html">Model Zoo</a></li> |
| <li><a class="main-nav-link" href="https://github.com/onnx/onnx-mxnet">ONNX</a></li> |
| </ul> |
| </span> |
| <span id="dropdown-menu-position-anchor-community"> |
| <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community"> |
| <li><a class="main-nav-link" href="http://discuss.mxnet.io">Forum</a></li> |
| <li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.5.0">Github</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/community/contribute.html">Contribute</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/community/ecosystem.html">Ecosystem</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/community/powered_by.html">Powered By</a></li> |
| </ul> |
| </span> |
| <span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">1.5.0<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a href="/">master</a></li><li><a href="/versions/1.7.0/">1.7.0</a></li><li><a href=/versions/1.6.0/>1.6.0</a></li><li><a href=/versions/1.5.0/>1.5.0</a></li><li><a href=/versions/1.4.1/>1.4.1</a></li><li><a href=/versions/1.3.1/>1.3.1</a></li><li><a href=/versions/1.2.1/>1.2.1</a></li><li><a href=/versions/1.1.0/>1.1.0</a></li><li><a href=/versions/1.0.0/>1.0.0</a></li><li><a href=/versions/0.12.1/>0.12.1</a></li><li><a href=/versions/0.11.0/>0.11.0</a></li></ul></span></nav> |
| <script> function getRootPath(){ return "../../" } </script> |
| <div class="burgerIcon dropdown"> |
| <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button">☰</a> |
| <ul class="dropdown-menu" id="burgerMenu"> |
| <li><a href="/versions/1.5.0/install/index.html">Install</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/tutorials/index.html">Tutorials</a></li> |
| <li class="dropdown-submenu dropdown"> |
| <a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Gluon</a> |
| <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu"> |
| <li><a class="main-nav-link" href="/versions/1.5.0/tutorials/gluon/gluon.html">About</a></li> |
| <li><a class="main-nav-link" href="http://gluon.mxnet.io">The Straight Dope (Tutorials)</a></li> |
| <li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li> |
| <li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li> |
| </ul> |
| </li> |
| <li class="dropdown-submenu"> |
| <a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">API</a> |
| <ul class="dropdown-menu"> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/python/index.html">Python</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/c++/index.html">C++</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/clojure/index.html">Clojure</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/java/index.html">Java</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/julia/index.html">Julia</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/perl/index.html">Perl</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/r/index.html">R</a></li> |
| <li><a class="main-nav-link" href="/versions/1.5.0/api/scala/index.html">Scala</a></li> |
| </ul> |
| </li> |
| <li class="dropdown-submenu"> |
| <a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Docs</a> |
| <ul class="dropdown-menu"> |
| <li><a href="/versions/1.5.0/faq/index.html" tabindex="-1">FAQ</a></li> |
| <li><a href="/versions/1.5.0/tutorials/index.html" tabindex="-1">Tutorials</a></li> |
| <li><a href="https://github.com/apache/incubator-mxnet/tree/1.5.0/example" tabindex="-1">Examples</a></li> |
| <li><a href="/versions/1.5.0/architecture/index.html" tabindex="-1">Architecture</a></li> |
| <li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home" tabindex="-1">Developer Wiki</a></li> |
| <li><a href="/versions/1.5.0/model_zoo/index.html" tabindex="-1">Gluon Model Zoo</a></li> |
| <li><a href="https://github.com/onnx/onnx-mxnet" tabindex="-1">ONNX</a></li> |
| </ul> |
| </li> |
| <li class="dropdown-submenu dropdown"> |
| <a aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" role="button" tabindex="-1">Community</a> |
| <ul class="dropdown-menu"> |
| <li><a href="http://discuss.mxnet.io" tabindex="-1">Forum</a></li> |
| <li><a href="https://github.com/apache/incubator-mxnet/tree/1.5.0" tabindex="-1">Github</a></li> |
| <li><a href="/versions/1.5.0/community/contribute.html" tabindex="-1">Contribute</a></li> |
| <li><a href="/versions/1.5.0/community/ecosystem.html" tabindex="-1">Ecosystem</a></li> |
| <li><a href="/versions/1.5.0/community/powered_by.html" tabindex="-1">Powered By</a></li> |
| </ul> |
| </li> |
| <li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">1.5.0</a><ul class="dropdown-menu"><li><a tabindex="-1" href=/>master</a></li><li><a tabindex="-1" href=/versions/1.6.0/>1.6.0</a></li><li><a tabindex="-1" href=/versions/1.5.0/>1.5.0</a></li><li><a tabindex="-1" href=/versions/1.4.1/>1.4.1</a></li><li><a tabindex="-1" href=/versions/1.3.1/>1.3.1</a></li><li><a tabindex="-1" href=/versions/1.2.1/>1.2.1</a></li><li><a tabindex="-1" href=/versions/1.1.0/>1.1.0</a></li><li><a tabindex="-1" href=/versions/1.0.0/>1.0.0</a></li><li><a tabindex="-1" href=/versions/0.12.1/>0.12.1</a></li><li><a tabindex="-1" href=/versions/0.11.0/>0.11.0</a></li></ul></li></ul> |
| </div> |
| <div class="plusIcon dropdown"> |
| <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a> |
| <ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul> |
| </div> |
| <div id="search-input-wrap"> |
| <form action="../../search.html" autocomplete="off" class="" method="get" role="search"> |
| <div class="form-group inner-addon left-addon"> |
| <i class="glyphicon glyphicon-search"></i> |
| <input class="form-control" name="q" placeholder="Search" type="text"/> |
| </div> |
| <input name="check_keywords" type="hidden" value="yes"> |
| <input name="area" type="hidden" value="default"/> |
| </input></form> |
| <div id="search-preview"></div> |
| </div> |
| <div id="searchIcon"> |
| <span aria-hidden="true" class="glyphicon glyphicon-search"></span> |
| </div> |
| </div> |
| </div> |
| </div> |
| <script type="text/javascript"> |
| $('body').css('background', 'white'); |
| </script> |
| <div class="container"> |
| <div class="row"> |
| <div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation"> |
| <div class="sphinxsidebarwrapper"> |
| <ul> |
| <li class="toctree-l1"><a class="reference internal" href="../../api/index.html">MXNet APIs</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">MXNet Architecture</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../community/index.html">MXNet Community</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../faq/index.html">MXNet FAQ</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../gluon/index.html">About Gluon</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../install/index.html">Installing MXNet</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../install/index.html#nvidia-jetson-tx-family">Nvidia Jetson TX family</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../install/index.html#source-download">Source Download</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../../model_zoo/index.html">MXNet Model Zoo</a></li> |
| <li class="toctree-l1"><a class="reference internal" href="../index.html">Tutorials</a></li> |
| </ul> |
| </div> |
| </div> |
| <div class="content"> |
| <div class="page-tracker"></div> |
| <!--- Licensed to the Apache Software Foundation (ASF) under one --> |
| <!--- or more contributor license agreements. See the NOTICE file --> |
| <!--- distributed with this work for additional information --> |
| <!--- regarding copyright ownership. The ASF licenses this file --> |
| <!--- to you under the Apache License, Version 2.0 (the --> |
| <!--- "License"); you may not use this file except in compliance --> |
| <!--- with the License. You may obtain a copy of the License at --><!--- http://www.apache.org/licenses/LICENSE-2.0 --><!--- Unless required by applicable law or agreed to in writing, --> |
| <!--- software distributed under the License is distributed on an --> |
| <!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --> |
| <!--- KIND, either express or implied. See the License for the --> |
| <!--- specific language governing permissions and limitations --> |
| <!--- under the License. --><div class="section" id="advanced-learning-rate-schedules"> |
| <span id="advanced-learning-rate-schedules"></span><h1>Advanced Learning Rate Schedules<a class="headerlink" href="#advanced-learning-rate-schedules" title="Permalink to this headline">¶</a></h1> |
| <p>Given the importance of the learning rate and the learning rate schedule for training neural networks, a number of research papers have been published recently on the subject. Although many practitioners use simple learning rate schedules such as stepwise decay, this research suggests that other strategies often work better. In this tutorial we implement a number of different schedule shapes and introduce cyclical schedules.</p> |
| <p>See the “Learning Rate Schedules” tutorial for a more basic overview of learning rates, and an example of how to use them while training your own models.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="o">%</span><span class="n">matplotlib</span> <span class="n">inline</span> |
| <span class="kn">import</span> <span class="nn">copy</span> |
| <span class="kn">import</span> <span class="nn">math</span> |
| <span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span> |
| <span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span> |
| <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="kn">as</span> <span class="nn">plt</span> |
| </pre></div> |
| </div> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">plot_schedule</span><span class="p">(</span><span class="n">schedule_fn</span><span class="p">,</span> <span class="n">iterations</span><span class="o">=</span><span class="mi">1500</span><span class="p">):</span> |
|     <span class="c1"># Iteration count starting at 1</span> |
|     <span class="n">iterations</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">iterations</span><span class="p">)]</span> |
|     <span class="n">lrs</span> <span class="o">=</span> <span class="p">[</span><span class="n">schedule_fn</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">iterations</span><span class="p">]</span> |
|     <span class="n">plt</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">iterations</span><span class="p">,</span> <span class="n">lrs</span><span class="p">)</span> |
|     <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">"Iteration"</span><span class="p">)</span> |
|     <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s2">"Learning Rate"</span><span class="p">)</span> |
|     <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> |
| </pre></div> |
| </div> |
| <div class="section" id="custom-schedule-shapes"> |
| <span id="custom-schedule-shapes"></span><h2>Custom Schedule Shapes<a class="headerlink" href="#custom-schedule-shapes" title="Permalink to this headline">¶</a></h2> |
| <div class="section" id="slanted-triangular"> |
| <span id="slanted-triangular"></span><h3>(Slanted) Triangular<a class="headerlink" href="#slanted-triangular" title="Permalink to this headline">¶</a></h3> |
| <p>While trying to push the boundaries of batch size for faster training, <a class="reference external" href="https://arxiv.org/abs/1706.02677">Priya Goyal et al. (2017)</a> found that having a smooth linear warm-up in the learning rate at the start of training improved the stability of the optimizer and led to better solutions. They also found that smooth increases gave better performance than stepwise increases.</p> |
| <p>We look at “warm-up” in more detail later in the tutorial, but this could be viewed as a specific case of the <strong>“triangular”</strong> schedule that was proposed by <a class="reference external" href="https://arxiv.org/abs/1506.01186">Leslie N. Smith (2015)</a>. Quite simply, the schedule linearly increases then decreases between a lower and upper bound. Originally it was suggested this schedule be used as part of a cyclical schedule but more recently researchers have been using a single cycle.</p> |
| <p>One adjustment proposed by <a class="reference external" href="https://arxiv.org/abs/1801.06146">Jeremy Howard, Sebastian Ruder (2018)</a> was to change the ratio between the increasing and decreasing stages, instead of using a 50:50 split. Changing the increasing fraction (<code class="docutils literal"><span class="pre">inc_fraction!=0.5</span></code>) leads to a <strong>“slanted triangular”</strong> schedule. Using <code class="docutils literal"><span class="pre">inc_fraction&lt;0.5</span></code> tends to give better results.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">TriangularSchedule</span><span class="p">():</span> |
|     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">min_lr</span><span class="p">,</span> <span class="n">max_lr</span><span class="p">,</span> <span class="n">cycle_length</span><span class="p">,</span> <span class="n">inc_fraction</span><span class="o">=</span><span class="mf">0.5</span><span class="p">):</span> |
|         <span class="sd">"""</span> |
| <span class="sd">        min_lr: lower bound for learning rate (float)</span> |
| <span class="sd">        max_lr: upper bound for learning rate (float)</span> |
| <span class="sd">        cycle_length: iterations between start and finish (int)</span> |
| <span class="sd">        inc_fraction: fraction of iterations spent in increasing stage (float)</span> |
| <span class="sd">        """</span> |
|         <span class="bp">self</span><span class="o">.</span><span class="n">min_lr</span> <span class="o">=</span> <span class="n">min_lr</span> |
|         <span class="bp">self</span><span class="o">.</span><span class="n">max_lr</span> <span class="o">=</span> <span class="n">max_lr</span> |
|         <span class="bp">self</span><span class="o">.</span><span class="n">cycle_length</span> <span class="o">=</span> <span class="n">cycle_length</span> |
|         <span class="bp">self</span><span class="o">.</span><span class="n">inc_fraction</span> <span class="o">=</span> <span class="n">inc_fraction</span> |
| |
|     <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">iteration</span><span class="p">):</span> |
|         <span class="k">if</span> <span class="n">iteration</span> <span class="o">&lt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cycle_length</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">inc_fraction</span><span class="p">:</span> |
|             <span class="n">unit_cycle</span> <span class="o">=</span> <span class="n">iteration</span> <span class="o">*</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cycle_length</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">inc_fraction</span><span class="p">)</span> |
|         <span class="k">elif</span> <span class="n">iteration</span> <span class="o">&lt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cycle_length</span><span class="p">:</span> |
|             <span class="n">unit_cycle</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cycle_length</span> <span class="o">-</span> <span class="n">iteration</span><span class="p">)</span> <span class="o">*</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cycle_length</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">inc_fraction</span><span class="p">))</span> |
|         <span class="k">else</span><span class="p">:</span> |
|             <span class="n">unit_cycle</span> <span class="o">=</span> <span class="mi">0</span> |
|         <span class="n">adjusted_cycle</span> <span class="o">=</span> <span class="p">(</span><span class="n">unit_cycle</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_lr</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_lr</span><span class="p">))</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_lr</span> |
|         <span class="k">return</span> <span class="n">adjusted_cycle</span> |
| </pre></div> |
| </div> |
| <p>We look at an example of a slanted triangular schedule that increases from a learning rate of 1 to 2, and back to 1, over 1000 iterations. Since we set <code class="docutils literal"><span class="pre">inc_fraction=0.2</span></code>, 200 iterations are used for the increasing stage, and 800 for the decreasing stage. After this, the schedule stays at the lower bound indefinitely.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">schedule</span> <span class="o">=</span> <span class="n">TriangularSchedule</span><span class="p">(</span><span class="n">min_lr</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">max_lr</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">cycle_length</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span> <span class="n">inc_fraction</span><span class="o">=</span><span class="mf">0.2</span><span class="p">)</span> |
| <span class="n">plot_schedule</span><span class="p">(</span><span class="n">schedule</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p><img alt="png" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_schedules/adv_triangular.png"> <!--notebook-skip-line--></img></p> |
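| <p>As a numerical check of this shape, the sketch below re-implements the same piecewise-linear rule as a standalone function and evaluates it at a few iterations. The helper name <code class="docutils literal"><span class="pre">triangular_lr</span></code> is hypothetical and not part of MXNet; it is assumed only to mirror <code class="docutils literal"><span class="pre">TriangularSchedule.__call__</span></code> so that the snippet runs on its own, with defaults chosen to match the example settings.</p> |

```python
def triangular_lr(iteration, min_lr=1.0, max_lr=2.0, cycle_length=1000, inc_fraction=0.2):
    """Hypothetical standalone helper mirroring TriangularSchedule.__call__."""
    peak = cycle_length * inc_fraction  # iteration at which the upper bound is reached
    if iteration <= peak:
        # Increasing stage: fraction of the way from min_lr up to max_lr
        unit_cycle = iteration / peak
    elif iteration <= cycle_length:
        # Decreasing stage: fraction of the way remaining back down to min_lr
        unit_cycle = (cycle_length - iteration) / (cycle_length - peak)
    else:
        # After the cycle the schedule stays at the lower bound
        unit_cycle = 0.0
    return unit_cycle * (max_lr - min_lr) + min_lr


for it in (100, 200, 600, 1000, 1500):
    print(it, round(triangular_lr(it), 4))
```

| <p>With these settings the learning rate is 1.5 halfway up the increasing stage (iteration 100), peaks at 2.0 at iteration 200, is back to 1.5 halfway down the decreasing stage (iteration 600), and is 1.0 from iteration 1000 onwards.</p> |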
| </div> |
| <div class="section" id="cosine"> |
| <span id="cosine"></span><h3>Cosine<a class="headerlink" href="#cosine" title="Permalink to this headline">¶</a></h3> |
| <p>Continuing with the idea that smooth decay profiles give improved performance over stepwise decay, <a class="reference external" href="https://arxiv.org/abs/1608.03983">Ilya Loshchilov, Frank Hutter (2016)</a> used <strong>“cosine annealing”</strong> schedules to good effect. As with triangular schedules, the original idea was that this should be used as part of a cyclical schedule, but we begin by implementing the cosine annealing component before the full Stochastic Gradient Descent with Warm Restarts (SGDR) method later in the tutorial.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">CosineAnnealingSchedule</span><span class="p">():</span> |
|     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">min_lr</span><span class="p">,</span> <span class="n">max_lr</span><span class="p">,</span> <span class="n">cycle_length</span><span class="p">):</span> |
|         <span class="sd">"""</span> |
| <span class="sd">        min_lr: lower bound for learning rate (float)</span> |
| <span class="sd">        max_lr: upper bound for learning rate (float)</span> |
| <span class="sd">        cycle_length: iterations between start and finish (int)</span> |
| <span class="sd">        """</span> |
|         <span class="bp">self</span><span class="o">.</span><span class="n">min_lr</span> <span class="o">=</span> <span class="n">min_lr</span> |
|         <span class="bp">self</span><span class="o">.</span><span class="n">max_lr</span> <span class="o">=</span> <span class="n">max_lr</span> |
|         <span class="bp">self</span><span class="o">.</span><span class="n">cycle_length</span> <span class="o">=</span> <span class="n">cycle_length</span> |
| |
|     <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">iteration</span><span class="p">):</span> |
|         <span class="k">if</span> <span class="n">iteration</span> <span class="o">&lt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cycle_length</span><span class="p">:</span> |
|             <span class="n">unit_cycle</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">math</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">iteration</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">cycle_length</span><span class="p">))</span> <span class="o">/</span> <span class="mi">2</span> |
|             <span class="n">adjusted_cycle</span> <span class="o">=</span> <span class="p">(</span><span class="n">unit_cycle</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_lr</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_lr</span><span class="p">))</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_lr</span> |
|             <span class="k">return</span> <span class="n">adjusted_cycle</span> |
|         <span class="k">else</span><span class="p">:</span> |
|             <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_lr</span> |
| </pre></div> |
| </div> |
| <p>We look at an example of a cosine annealing schedule that smoothly decreases from a learning rate of 2 to 1 across 1000 iterations. After this, the schedule stays at the lower bound indefinitely.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">schedule</span> <span class="o">=</span> <span class="n">CosineAnnealingSchedule</span><span class="p">(</span><span class="n">min_lr</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">max_lr</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">cycle_length</span><span class="o">=</span><span class="mi">1000</span><span class="p">)</span> |
| <span class="n">plot_schedule</span><span class="p">(</span><span class="n">schedule</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p><img alt="png" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_schedules/adv_cosine.png"/> <!--notebook-skip-line--></p> |
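<p>To make the shape of the curve concrete, the following is a minimal sketch that restates the cosine rule as a plain function and evaluates it at a few iterations. The function name <code class="docutils literal"><span class="pre">cosine_annealing</span></code> is hypothetical, and the boundary check against <code class="docutils literal"><span class="pre">cycle_length</span></code> is an assumption about the class above rather than a copy of it.</p>

```python
import math


def cosine_annealing(iteration, min_lr, max_lr, cycle_length):
    """Hypothetical stand-alone restatement of the cosine annealing rule."""
    # Assumption: the cycle covers iterations 0..cycle_length inclusive.
    if iteration <= cycle_length:
        # unit_cycle falls smoothly from 1 (iteration 0) to 0 (end of the cycle)
        unit_cycle = (1 + math.cos(iteration * math.pi / cycle_length)) / 2
        # rescale from [0, 1] to [min_lr, max_lr]
        return unit_cycle * (max_lr - min_lr) + min_lr
    # after the cycle, hold the learning rate at the lower bound
    return min_lr


# Same settings as the example above: from 2 down to 1 over 1000 iterations.
for it in (0, 250, 500, 750, 1000, 1500):
    print(it, round(cosine_annealing(it, min_lr=1, max_lr=2, cycle_length=1000), 4))
```

<p>The learning rate passes through the midpoint of the two bounds exactly halfway through the cycle, and it changes most slowly near the start and the end of the cycle.</p>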
| </div> |
| </div> |
| <div class="section" id="custom-schedule-modifiers"> |
| <span id="custom-schedule-modifiers"></span><h2>Custom Schedule Modifiers<a class="headerlink" href="#custom-schedule-modifiers" title="Permalink to this headline">¶</a></h2> |
| <p>We now take a look at some adjustments that can be made to existing schedules. We see how to add linear warm-up and its complement, linear cool-down, before using these to implement the “1-Cycle” schedule used by <a class="reference external" href="https://arxiv.org/abs/1708.07120">Leslie N. Smith, Nicholay Topin (2017)</a> for “super-convergence”. We then look at cyclical schedules and implement the original cyclical schedule from <a class="reference external" href="https://arxiv.org/abs/1506.01186">Leslie N. Smith (2015)</a> before finishing with a look at <a class="reference external" href="https://arxiv.org/abs/1608.03983">“SGDR: Stochastic Gradient Descent with Warm Restarts” by Ilya Loshchilov, Frank Hutter (2016)</a>.</p> |
| <p>Unlike the schedules above and those implemented in <code class="docutils literal"><span class="pre">mx.lr_scheduler</span></code>, these classes are designed to modify existing schedules, so at initialization they take either the argument <code class="docutils literal"><span class="pre">schedule</span></code> (an already initialized schedule) or <code class="docutils literal"><span class="pre">schedule_class</span></code> (a schedule class to be instantiated later).</p> |
| <div class="section" id="warm-up"> |
| <span id="warm-up"></span><h3>Warm-Up<a class="headerlink" href="#warm-up" title="Permalink to this headline">¶</a></h3> |
| <p>Using the idea of linear warm-up of the learning rate proposed in <a class="reference external" href="https://arxiv.org/abs/1706.02677">“Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour” by Priya Goyal et al. (2017)</a>, we implement a wrapper class that adds warm-up to an existing schedule. Going from <code class="docutils literal"><span class="pre">start_lr</span></code> to the initial learning rate of the <code class="docutils literal"><span class="pre">schedule</span></code> over <code class="docutils literal"><span class="pre">length</span></code> iterations, this adjustment is useful when training with large batch sizes.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">LinearWarmUp</span><span class="p">():</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">schedule</span><span class="p">,</span> <span class="n">start_lr</span><span class="p">,</span> <span class="n">length</span><span class="p">):</span> |
| <span class="sd">"""</span> |
| <span class="sd"> schedule: a pre-initialized schedule (e.g. TriangularSchedule(min_lr=0.5, max_lr=2, cycle_length=500))</span> |
| <span class="sd"> start_lr: learning rate used at start of the warm-up (float)</span> |
| <span class="sd"> length: number of iterations used for the warm-up (int)</span> |
| <span class="sd"> """</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">schedule</span> <span class="o">=</span> <span class="n">schedule</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">start_lr</span> <span class="o">=</span> <span class="n">start_lr</span> |
| <span class="c1"># calling mx.lr_scheduler.LRScheduler affects its state, so we call a copy instead</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">finish_lr</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">schedule</span><span class="p">)(</span><span class="mi">0</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">length</span> <span class="o">=</span> <span class="n">length</span> |
| |
| <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">iteration</span><span class="p">):</span> |
| <span class="k">if</span> <span class="n">iteration</span> <span class="o"><=</span> <span class="bp">self</span><span class="o">.</span><span class="n">length</span><span class="p">:</span> |
| <span class="k">return</span> <span class="n">iteration</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">finish_lr</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">start_lr</span><span class="p">)</span><span class="o">/</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">length</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">start_lr</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">schedule</span><span class="p">(</span><span class="n">iteration</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">length</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>As an example, we add a linear warm-up of the learning rate (from 0 to 1 over 250 iterations) to a stepwise decay schedule. We first create the <code class="docutils literal"><span class="pre">MultiFactorScheduler</span></code> (and set the <code class="docutils literal"><span class="pre">base_lr</span></code>) and then pass it to <code class="docutils literal"><span class="pre">LinearWarmUp</span></code> to add the warm-up at the start. We can use <code class="docutils literal"><span class="pre">LinearWarmUp</span></code> with any other schedule including <code class="docutils literal"><span class="pre">CosineAnnealingSchedule</span></code>.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">schedule</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">MultiFactorScheduler</span><span class="p">(</span><span class="n">step</span><span class="o">=</span><span class="p">[</span><span class="mi">250</span><span class="p">,</span> <span class="mi">750</span><span class="p">,</span> <span class="mi">900</span><span class="p">],</span> <span class="n">factor</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span> |
| <span class="n">schedule</span><span class="o">.</span><span class="n">base_lr</span> <span class="o">=</span> <span class="mi">1</span> |
| <span class="n">schedule</span> <span class="o">=</span> <span class="n">LinearWarmUp</span><span class="p">(</span><span class="n">schedule</span><span class="p">,</span> <span class="n">start_lr</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">length</span><span class="o">=</span><span class="mi">250</span><span class="p">)</span> |
| <span class="n">plot_schedule</span><span class="p">(</span><span class="n">schedule</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p><img alt="png" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_schedules/adv_warmup.png"/> <!--notebook-skip-line--></p> |
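<p>One consequence of the wrapper is that the wrapped schedule is shifted by <code class="docutils literal"><span class="pre">length</span></code> iterations, because it is called with <code class="docutils literal"><span class="pre">iteration</span> <span class="pre">-</span> <span class="pre">length</span></code> once the warm-up ends. The sketch below illustrates this with a hypothetical stateless stand-in for the stepwise decay schedule (<code class="docutils literal"><span class="pre">step_decay</span></code> and <code class="docutils literal"><span class="pre">with_warm_up</span></code> are illustrative names, not part of MXNet), so it can be run without the library.</p>

```python
def step_decay(iteration, base_lr=1.0, steps=(250, 750, 900), factor=0.5):
    """Hypothetical stand-in for a stepwise decay schedule (not the MXNet class)."""
    # count how many step boundaries have already been passed
    passed = sum(1 for s in steps if iteration > s)
    return base_lr * factor ** passed


def with_warm_up(iteration, start_lr=0.0, length=250):
    """Linear warm-up followed by the shifted stand-in schedule."""
    finish_lr = step_decay(0)  # warm-up ends at the schedule's initial rate
    if iteration <= length:
        # linear interpolation from start_lr to finish_lr
        return iteration * (finish_lr - start_lr) / length + start_lr
    # the wrapped schedule only starts counting after the warm-up
    return step_decay(iteration - length)


for it in (0, 125, 250, 400, 600, 1100, 1200):
    print(it, with_warm_up(it))
```

<p>With these settings the first halving, which the underlying schedule applies after its own iteration 250, takes effect shortly after overall iteration 500.</p>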
| </div> |
| <div class="section" id="cool-down"> |
| <span id="cool-down"></span><h3>Cool-Down<a class="headerlink" href="#cool-down" title="Permalink to this headline">¶</a></h3> |
| <p>Similarly, we can add a linear cool-down period to our schedule. This is used in the “1-Cycle” schedule proposed by <a class="reference external" href="https://arxiv.org/abs/1708.07120">Leslie N. Smith, Nicholay Topin (2017)</a> to train neural networks very quickly in certain circumstances (a phenomenon the authors call “super-convergence”). We reduce the learning rate from its value at <code class="docutils literal"><span class="pre">start_idx</span></code> of <code class="docutils literal"><span class="pre">schedule</span></code> to <code class="docutils literal"><span class="pre">finish_lr</span></code> over a period of <code class="docutils literal"><span class="pre">length</span></code> iterations, and then maintain <code class="docutils literal"><span class="pre">finish_lr</span></code> thereafter.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">LinearCoolDown</span><span class="p">():</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">schedule</span><span class="p">,</span> <span class="n">finish_lr</span><span class="p">,</span> <span class="n">start_idx</span><span class="p">,</span> <span class="n">length</span><span class="p">):</span> |
| <span class="sd">"""</span> |
| <span class="sd"> schedule: a pre-initialized schedule (e.g. TriangularSchedule(min_lr=0.5, max_lr=2, cycle_length=500))</span> |
| <span class="sd"> finish_lr: learning rate used at end of the cool-down (float)</span> |
| <span class="sd"> start_idx: iteration to start the cool-down (int)</span> |
| <span class="sd"> length: number of iterations used for the cool-down (int)</span> |
| <span class="sd"> """</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">schedule</span> <span class="o">=</span> <span class="n">schedule</span> |
| <span class="c1"># calling mx.lr_scheduler.LRScheduler affects its state, so we call a copy instead</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">start_lr</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">schedule</span><span class="p">)(</span><span class="n">start_idx</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">finish_lr</span> <span class="o">=</span> <span class="n">finish_lr</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">start_idx</span> <span class="o">=</span> <span class="n">start_idx</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">finish_idx</span> <span class="o">=</span> <span class="n">start_idx</span> <span class="o">+</span> <span class="n">length</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">length</span> <span class="o">=</span> <span class="n">length</span> |
| |
| <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">iteration</span><span class="p">):</span> |
| <span class="k">if</span> <span class="n">iteration</span> <span class="o"><=</span> <span class="bp">self</span><span class="o">.</span><span class="n">start_idx</span><span class="p">:</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">schedule</span><span class="p">(</span><span class="n">iteration</span><span class="p">)</span> |
| <span class="k">elif</span> <span class="n">iteration</span> <span class="o"><=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finish_idx</span><span class="p">:</span> |
| <span class="k">return</span> <span class="p">(</span><span class="n">iteration</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">start_idx</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">finish_lr</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">start_lr</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">length</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">start_lr</span> |
| <span class="k">else</span><span class="p">:</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">finish_lr</span> |
| </pre></div> |
| </div> |
| <p>As an example, we apply learning rate cool-down to a <code class="docutils literal"><span class="pre">MultiFactorScheduler</span></code>. Starting the cool-down at iteration 1000, we reduce the learning rate linearly from 0.125 to 0.001 over 500 iterations, and hold the learning rate at 0.001 after this.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">schedule</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">MultiFactorScheduler</span><span class="p">(</span><span class="n">step</span><span class="o">=</span><span class="p">[</span><span class="mi">250</span><span class="p">,</span> <span class="mi">750</span><span class="p">,</span> <span class="mi">900</span><span class="p">],</span> <span class="n">factor</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span> |
| <span class="n">schedule</span><span class="o">.</span><span class="n">base_lr</span> <span class="o">=</span> <span class="mi">1</span> |
| <span class="n">schedule</span> <span class="o">=</span> <span class="n">LinearCoolDown</span><span class="p">(</span><span class="n">schedule</span><span class="p">,</span> <span class="n">finish_lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">start_idx</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span> <span class="n">length</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span> |
| <span class="n">plot_schedule</span><span class="p">(</span><span class="n">schedule</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p><img alt="png" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_schedules/adv_cooldown.png"/> <!--notebook-skip-line--></p> |
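<p>The numbers in this example can be checked by hand: three halvings of a base learning rate of 1 give 0.125 at iteration 1000, and the cool-down then interpolates linearly towards 0.001. The following sketch reproduces that arithmetic with a hypothetical stateless stand-in for the stepwise decay schedule (<code class="docutils literal"><span class="pre">step_decay</span></code> and <code class="docutils literal"><span class="pre">with_cool_down</span></code> are illustrative names, not MXNet functions).</p>

```python
def step_decay(iteration, base_lr=1.0, steps=(250, 750, 900), factor=0.5):
    """Hypothetical stand-in for a stepwise decay schedule (not the MXNet class)."""
    passed = sum(1 for s in steps if iteration > s)
    return base_lr * factor ** passed


def with_cool_down(iteration, finish_lr=0.001, start_idx=1000, length=500):
    """Stand-in schedule up to start_idx, then a linear cool-down to finish_lr."""
    start_lr = step_decay(start_idx)  # 1 * 0.5**3 = 0.125 with the defaults
    if iteration <= start_idx:
        return step_decay(iteration)
    if iteration <= start_idx + length:
        # linear interpolation from start_lr down to finish_lr
        return (iteration - start_idx) * (finish_lr - start_lr) / length + start_lr
    # after the cool-down, hold the final learning rate
    return finish_lr


for it in (1000, 1250, 1500, 2000):
    print(it, round(with_cool_down(it), 4))
```

<p>Halfway through the cool-down, at iteration 1250, the learning rate is the average of 0.125 and 0.001.</p>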
| <div class="section" id="cycle-for-super-convergence"> |
| <span id="cycle-for-super-convergence"></span><h4>1-Cycle: for “Super-Convergence”<a class="headerlink" href="#cycle-for-super-convergence" title="Permalink to this headline">¶</a></h4> |
| <p>To implement the “1-Cycle” schedule proposed by <a class="reference external" href="https://arxiv.org/abs/1708.07120">Leslie N. Smith, Nicholay Topin (2017)</a>, we use a single symmetric cycle of the triangular schedule above (i.e. <code class="docutils literal"><span class="pre">inc_fraction=0.5</span></code>), followed by a cool-down period of <code class="docutils literal"><span class="pre">cooldown_length</span></code> iterations.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">OneCycleSchedule</span><span class="p">():</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">start_lr</span><span class="p">,</span> <span class="n">max_lr</span><span class="p">,</span> <span class="n">cycle_length</span><span class="p">,</span> <span class="n">cooldown_length</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">finish_lr</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span> |
| <span class="sd">"""</span> |
| <span class="sd"> start_lr: lower bound for learning rate in triangular cycle (float)</span> |
| <span class="sd"> max_lr: upper bound for learning rate in triangular cycle (float)</span> |
| <span class="sd"> cycle_length: iterations between start and finish of triangular cycle: 2x 'stepsize' (int)</span> |
| <span class="sd"> cooldown_length: number of iterations used for the cool-down (int)</span> |
| <span class="sd"> finish_lr: learning rate used at end of the cool-down (float)</span> |
| <span class="sd"> """</span> |
| <span class="k">if</span> <span class="p">(</span><span class="n">cooldown_length</span> <span class="o">></span> <span class="mi">0</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span><span class="n">finish_lr</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">):</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Must specify finish_lr when using cooldown_length > 0."</span><span class="p">)</span> |
| <span class="k">if</span> <span class="p">(</span><span class="n">cooldown_length</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span><span class="n">finish_lr</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">):</span> |
| <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Must specify cooldown_length > 0 when using finish_lr."</span><span class="p">)</span> |
| |
| <span class="n">finish_lr</span> <span class="o">=</span> <span class="n">finish_lr</span> <span class="k">if</span> <span class="p">(</span><span class="n">cooldown_length</span> <span class="o">></span> <span class="mi">0</span><span class="p">)</span> <span class="k">else</span> <span class="n">start_lr</span> |
| <span class="n">schedule</span> <span class="o">=</span> <span class="n">TriangularSchedule</span><span class="p">(</span><span class="n">min_lr</span><span class="o">=</span><span class="n">start_lr</span><span class="p">,</span> <span class="n">max_lr</span><span class="o">=</span><span class="n">max_lr</span><span class="p">,</span> <span class="n">cycle_length</span><span class="o">=</span><span class="n">cycle_length</span><span class="p">)</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">schedule</span> <span class="o">=</span> <span class="n">LinearCoolDown</span><span class="p">(</span><span class="n">schedule</span><span class="p">,</span> <span class="n">finish_lr</span><span class="o">=</span><span class="n">finish_lr</span><span class="p">,</span> <span class="n">start_idx</span><span class="o">=</span><span class="n">cycle_length</span><span class="p">,</span> <span class="n">length</span><span class="o">=</span><span class="n">cooldown_length</span><span class="p">)</span> |
| |
| <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">iteration</span><span class="p">):</span> |
| <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">schedule</span><span class="p">(</span><span class="n">iteration</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p>As an example, we linearly increase and then decrease the learning rate from 0.1 to 0.5 and back over 500 iterations (i.e. single triangular cycle), before reducing the learning rate further to 0.001 over the next 750 iterations (i.e. cool-down).</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">schedule</span> <span class="o">=</span> <span class="n">OneCycleSchedule</span><span class="p">(</span><span class="n">start_lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">max_lr</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">cycle_length</span><span class="o">=</span><span class="mi">500</span><span class="p">,</span> <span class="n">cooldown_length</span><span class="o">=</span><span class="mi">750</span><span class="p">,</span> <span class="n">finish_lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">)</span> |
| <span class="n">plot_schedule</span><span class="p">(</span><span class="n">schedule</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p><img alt="png" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_schedules/adv_onecycle.png"/> <!--notebook-skip-line--></p> |
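<p>The three phases of this schedule can also be written out as a single piecewise-linear function. The sketch below is a flattened re-derivation for illustration only: <code class="docutils literal"><span class="pre">one_cycle_lr</span></code> is a hypothetical name, and the function assumes a symmetric triangle rather than reusing the classes defined in this tutorial.</p>

```python
def one_cycle_lr(iteration, start_lr, max_lr, cycle_length, cooldown_length, finish_lr):
    """Hypothetical piecewise-linear restatement of a 1-Cycle style schedule."""
    half = cycle_length / 2  # assumption: symmetric triangle (inc_fraction=0.5)
    if iteration <= half:
        # phase 1: rise from start_lr to max_lr
        return start_lr + (max_lr - start_lr) * iteration / half
    if iteration <= cycle_length:
        # phase 2: fall from max_lr back to start_lr
        return max_lr - (max_lr - start_lr) * (iteration - half) / half
    if iteration <= cycle_length + cooldown_length:
        # phase 3: cool down from start_lr to finish_lr
        return start_lr + (finish_lr - start_lr) * (iteration - cycle_length) / cooldown_length
    # afterwards: hold the final learning rate
    return finish_lr


# Same settings as the example above.
params = dict(start_lr=0.1, max_lr=0.5, cycle_length=500, cooldown_length=750, finish_lr=0.001)
for it in (0, 250, 500, 875, 1250, 2000):
    print(it, round(one_cycle_lr(it, **params), 4))
```

<p>The peak is reached at iteration 250, the triangle closes at iteration 500, and the cool-down ends at iteration 1250.</p>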
| </div> |
| </div> |
| <div class="section" id="cyclical"> |
| <span id="cyclical"></span><h3>Cyclical<a class="headerlink" href="#cyclical" title="Permalink to this headline">¶</a></h3> |
| <p>Originally proposed by <a class="reference external" href="https://arxiv.org/abs/1506.01186">Leslie N. Smith (2015)</a>, the idea of cyclically increasing and decreasing the learning rate has been reported to give faster convergence and better solutions. We implement a wrapper class that loops existing cycle-based schedules such as <code class="docutils literal"><span class="pre">TriangularSchedule</span></code> and <code class="docutils literal"><span class="pre">CosineAnnealingSchedule</span></code> to provide infinitely repeating schedules. We pass the schedule class (rather than an instance) because one feature of the <code class="docutils literal"><span class="pre">CyclicalSchedule</span></code> is to vary the <code class="docutils literal"><span class="pre">cycle_length</span></code> over time, as seen in <a class="reference external" href="https://arxiv.org/abs/1608.03983">Ilya Loshchilov, Frank Hutter (2016)</a>, using <code class="docutils literal"><span class="pre">cycle_length_decay</span></code>. Another feature is the ability to decay the cycle magnitude over time with <code class="docutils literal"><span class="pre">cycle_magnitude_decay</span></code>.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">CyclicalSchedule</span><span class="p">():</span> |
| <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">schedule_class</span><span class="p">,</span> <span class="n">cycle_length</span><span class="p">,</span> <span class="n">cycle_length_decay</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">cycle_magnitude_decay</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> |
| <span class="sd">"""</span> |
| <span class="sd"> schedule_class: class of schedule, expected to take `cycle_length` argument</span> |
| <span class="sd"> cycle_length: iterations used for initial cycle (int)</span> |
| <span class="sd"> cycle_length_decay: factor multiplied to cycle_length each cycle (float)</span> |
| <span class="sd"> cycle_magnitude_decay: factor multiplied to learning rate magnitudes each cycle (float)</span> |
| <span class="sd"> kwargs: passed to the schedule_class</span> |
| <span class="sd"> """</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">schedule_class</span> <span class="o">=</span> <span class="n">schedule_class</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">length</span> <span class="o">=</span> <span class="n">cycle_length</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">length_decay</span> <span class="o">=</span> <span class="n">cycle_length_decay</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">magnitude_decay</span> <span class="o">=</span> <span class="n">cycle_magnitude_decay</span> |
| <span class="bp">self</span><span class="o">.</span><span class="n">kwargs</span> <span class="o">=</span> <span class="n">kwargs</span> |
| |
| <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">iteration</span><span class="p">):</span> |
| <span class="n">cycle_idx</span> <span class="o">=</span> <span class="mi">0</span> |
| <span class="n">cycle_length</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">length</span> |
| <span class="n">idx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">length</span> |
| <span class="k">while</span> <span class="n">idx</span> <span class="o"><=</span> <span class="n">iteration</span><span class="p">:</span> |
| <span class="n">cycle_length</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">cycle_length</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">length_decay</span><span class="p">)</span> |
| <span class="n">cycle_idx</span> <span class="o">+=</span> <span class="mi">1</span> |
| <span class="n">idx</span> <span class="o">+=</span> <span class="n">cycle_length</span> |
| <span class="n">cycle_offset</span> <span class="o">=</span> <span class="n">iteration</span> <span class="o">-</span> <span class="n">idx</span> <span class="o">+</span> <span class="n">cycle_length</span> |
| |
| <span class="n">schedule</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">schedule_class</span><span class="p">(</span><span class="n">cycle_length</span><span class="o">=</span><span class="n">cycle_length</span><span class="p">,</span> <span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">kwargs</span><span class="p">)</span> |
| <span class="k">return</span> <span class="n">schedule</span><span class="p">(</span><span class="n">cycle_offset</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">magnitude_decay</span><span class="o">**</span><span class="n">cycle_idx</span> |
| </pre></div> |
| </div> |
| <p>As an example, we implement the triangular cyclical schedule presented in <a class="reference external" href="https://arxiv.org/abs/1506.01186">“Cyclical Learning Rates for Training Neural Networks” by Leslie N. Smith (2015)</a>. We use slightly different terminology to the paper here because we use <code class="docutils literal"><span class="pre">cycle_length</span></code> that is twice the ‘stepsize’ used in the paper. We repeat cycles, each with a length of 500 iterations and lower and upper learning rate bounds of 0.5 and 2 respectively.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">schedule</span> <span class="o">=</span> <span class="n">CyclicalSchedule</span><span class="p">(</span><span class="n">TriangularSchedule</span><span class="p">,</span> <span class="n">min_lr</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">max_lr</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">cycle_length</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span> |
| <span class="n">plot_schedule</span><span class="p">(</span><span class="n">schedule</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p><img alt="png" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_schedules/adv_cyclical.png"/> <!--notebook-skip-line--></p> |
| <p>Lastly, we implement the schedule used in <a class="reference external" href="https://arxiv.org/abs/1608.03983">“SGDR: Stochastic Gradient Descent with Warm Restarts” by Ilya Loshchilov, Frank Hutter (2016)</a>. We repeat cosine annealing schedules, but each time we halve the magnitude and double the cycle length.</p> |
| <div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">schedule</span> <span class="o">=</span> <span class="n">CyclicalSchedule</span><span class="p">(</span><span class="n">CosineAnnealingSchedule</span><span class="p">,</span> <span class="n">min_lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">max_lr</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> |
| <span class="n">cycle_length</span><span class="o">=</span><span class="mi">250</span><span class="p">,</span> <span class="n">cycle_length_decay</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">cycle_magnitude_decay</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span> |
| <span class="n">plot_schedule</span><span class="p">(</span><span class="n">schedule</span><span class="p">)</span> |
| </pre></div> |
| </div> |
| <p><img alt="png" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_schedules/adv_sgdr.png"/> <!--notebook-skip-line--></p> |
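<p>To see where the restarts fall, the cycle bookkeeping in <code class="docutils literal"><span class="pre">CyclicalSchedule.__call__</span></code> can be isolated into a small helper. This is a sketch that mirrors the while-loop above; the name <code class="docutils literal"><span class="pre">locate_cycle</span></code> is hypothetical and the helper is not part of the tutorial's classes.</p>

```python
import math


def locate_cycle(iteration, cycle_length, cycle_length_decay=1):
    """Return (cycle_idx, cycle_length, cycle_offset) for a given iteration."""
    cycle_idx = 0
    idx = cycle_length  # first iteration that belongs to the next cycle
    while idx <= iteration:
        # each new cycle has its length scaled by the decay factor
        cycle_length = math.ceil(cycle_length * cycle_length_decay)
        cycle_idx += 1
        idx += cycle_length
    # position of the iteration inside its own cycle
    cycle_offset = iteration - idx + cycle_length
    return cycle_idx, cycle_length, cycle_offset


# Same cycle settings as the SGDR example above.
for it in (0, 249, 250, 749, 750, 1000):
    cycle_idx, length, offset = locate_cycle(it, cycle_length=250, cycle_length_decay=2)
    scale = 0.5 ** cycle_idx  # effect of cycle_magnitude_decay=0.5
    print(it, cycle_idx, length, offset, scale)
```

<p>With these settings the cycles last 250, 500 and 1000 iterations, so restarts occur at iterations 250 and 750. Because the decay factor multiplies the whole output of the wrapped schedule, both the upper and the lower learning rate bounds are halved in each successive cycle.</p>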
| <p><strong><em>Want to learn more?</em></strong> Check out the “Learning Rate Schedules” tutorial for a more basic overview of the learning rate schedules found in <code class="docutils literal"><span class="pre">mx.lr_scheduler</span></code>, and an example of how to use them while training your own models.</p> |
| <div class="btn-group" role="group"> |
| <div class="download-btn"><a download="learning_rate_schedules_advanced.ipynb" href="learning_rate_schedules_advanced.ipynb"><span class="glyphicon glyphicon-download-alt"></span> learning_rate_schedules_advanced.ipynb</a></div></div></div> |
| </div> |
| </div> |
| </div> |
| </div> |
| <div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation"> |
| <div class="sphinxsidebarwrapper"> |
| <h3><a href="../../index.html">Table Of Contents</a></h3> |
| <ul> |
| <li><a class="reference internal" href="#">Advanced Learning Rate Schedules</a><ul> |
| <li><a class="reference internal" href="#custom-schedule-shapes">Custom Schedule Shapes</a><ul> |
| <li><a class="reference internal" href="#slanted-triangular">(Slanted) Triangular</a></li> |
| <li><a class="reference internal" href="#cosine">Cosine</a></li> |
| </ul> |
| </li> |
| <li><a class="reference internal" href="#custom-schedule-modifiers">Custom Schedule Modifiers</a><ul> |
| <li><a class="reference internal" href="#warm-up">Warm-Up</a></li> |
| <li><a class="reference internal" href="#cool-down">Cool-Down</a><ul> |
| <li><a class="reference internal" href="#cycle-for-super-convergence">1-Cycle: for “Super-Convergence”</a></li> |
| </ul> |
| </li> |
| <li><a class="reference internal" href="#cyclical">Cyclical</a></li> |
| </ul> |
| </li> |
| </ul> |
| </li> |
| </ul> |
| </div> |
| </div> |
| </div><div class="footer"> |
| <div class="section-disclaimer"> |
| <div class="container"> |
| <div> |
| <img height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/> |
| <p> |
| Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF. |
| </p> |
| <p> |
| "Copyright © 2017-2018, The Apache Software Foundation |
| Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation." |
| </p> |
| </div> |
| </div> |
| </div> |
| </div> <!-- pagename != index --> |
| </div> |
| <script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script> |
| <script src="../../_static/js/sidebar.js" type="text/javascript"></script> |
| <script src="../../_static/js/search.js" type="text/javascript"></script> |
| <script src="../../_static/js/navbar.js" type="text/javascript"></script> |
| <script src="../../_static/js/clipboard.min.js" type="text/javascript"></script> |
| <script src="../../_static/js/copycode.js" type="text/javascript"></script> |
| <script src="../../_static/js/page.js" type="text/javascript"></script> |
| <script src="../../_static/js/docversion.js" type="text/javascript"></script> |
| <script type="text/javascript"> |
| $('body').ready(function () { |
| $('body').css('visibility', 'visible'); |
| }); |
| </script> |
| </body> |
| </html> |