{
"cells": [
{
"cell_type": "markdown",
"id": "e3da3cad",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"\n",
"# Learning Rate Finder\n",
"\n",
"Setting the learning rate for stochastic gradient descent (SGD) is crucially important when training neural network because it controls both the speed of convergence and the ultimate performance of the network. Set the learning too low and you could be twiddling your thumbs for quite some time as the parameters update very slowly. Set it too high and the updates will skip over optimal solutions, or worse the optimizer might not converge at all!\n",
"\n",
"Leslie Smith from the U.S. Naval Research Laboratory presented a method for finding a good learning rate in a paper called [\"Cyclical Learning Rates for Training Neural Networks\"](https://arxiv.org/abs/1506.01186). We implement this method in MXNet (with the Gluon API) and create a 'Learning Rate Finder' which you can use while training your own networks. We take a look at the central idea of the paper, cyclical learning rate schedules, in the ['Advanced Learning Rate Schedules'](/api/python/docs/tutorials/packages/gluon/training/learning_rates/learning_rate_schedules_advanced.html) tutorial.\n",
"\n",
"## Simple Idea\n",
"\n",
"Given an initialized network, a defined loss and a training dataset we take the following steps:\n",
"\n",
"1. Train one batch at a time (a.k.a. an iteration)\n",
"2. Start with a very small learning rate (e.g. 0.000001) and slowly increase it every iteration\n",
"3. Record the training loss and continue until we see the training loss diverge\n",
"\n",
"We then analyse the results by plotting a graph of the learning rate against the training loss as seen below (taking note of the log scales).\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_finder/finder_plot_w_annotations.png\" width=\"500px\"/> <!--notebook-skip-line-->\n",
"\n",
"As expected, for very small learning rates we don't see much change in the loss as the parameter updates are negligible. At a learning rate of 0.001, we start to see the loss fall. Setting the initial learning rate here is reasonable, but we still have the potential to learn faster. We observe a drop in the loss up until 0.1 where the loss appears to diverge. We want to set the initial learning rate as high as possible before the loss becomes unstable, so we choose a learning rate of 0.05.\n",
"\n",
"## Epoch to Iteration\n",
"\n",
"Usually, our unit of work is an epoch (a full pass through the dataset) and the learning rate would typically be held constant throughout the epoch. With the Learning Rate Finder (and cyclical learning rate schedules) we are required to vary the learning rate every iteration. As such we structure our training code so that a single iteration can be run with a given learning rate. You can implement Learner as you wish. Just initialize the network, define the loss and trainer in `__init__` and keep your training logic for a single batch in `iteration`."
]
},
{
"cell_type": "markdown",
"id": "ec7c869a",
"metadata": {},
"source": [
"```python\n",
"import mxnet as mx\n",
"\n",
"# Set seed for reproducibility\n",
"mx.random.seed(42)\n",
"\n",
"class Learner():\n",
" def __init__(self, net, data_loader, ctx):\n",
" \"\"\"\n",
" :param net: network (mx.gluon.Block)\n",
" :param data_loader: training data loader (mx.gluon.data.DataLoader)\n",
" :param ctx: context (mx.gpu or mx.cpu)\n",
" \"\"\"\n",
" self.net = net\n",
" self.data_loader = data_loader\n",
" self.ctx = ctx\n",
" # So we don't need to be in `for batch in data_loader` scope\n",
" # and can call for next batch in `iteration`\n",
" self.data_loader_iter = iter(self.data_loader)\n",
" self.net.initialize(mx.init.Xavier(), ctx=self.ctx)\n",
" self.loss_fn = mx.gluon.loss.SoftmaxCrossEntropyLoss()\n",
" self.trainer = mx.gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .001})\n",
"\n",
" def iteration(self, lr=None, take_step=True):\n",
" \"\"\"\n",
" :param lr: learning rate to use for iteration (float)\n",
" :param take_step: take trainer step to update weights (boolean)\n",
" :return: iteration loss (float)\n",
" \"\"\"\n",
" # Update learning rate if different this iteration\n",
" if lr and (lr != self.trainer.learning_rate):\n",
" self.trainer.set_learning_rate(lr)\n",
" # Get next batch, and move context (e.g. to GPU if set)\n",
" data, label = next(self.data_loader_iter)\n",
" data = data.as_in_context(self.ctx)\n",
" label = label.as_in_context(self.ctx)\n",
" # Standard forward and backward pass\n",
" with mx.autograd.record():\n",
" output = self.net(data)\n",
" loss = self.loss_fn(output, label)\n",
" loss.backward()\n",
" # Update parameters\n",
" if take_step: self.trainer.step(data.shape[0])\n",
" # Set and return loss.\n",
" self.iteration_loss = mx.nd.mean(loss).asscalar()\n",
" return self.iteration_loss\n",
"\n",
" def close(self):\n",
" # Close open iterator and associated workers\n",
" self.data_loader_iter.shutdown()\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "fb8d9911",
"metadata": {},
"source": [
"We also adjust our `DataLoader` so that it continuously provides batches of data and doesn't stop after a single epoch. We can then call `iteration` as many times as required for the loss to diverge as part of the Learning Rate Finder process. We implement a custom `BatchSampler` for this, that keeps returning random indices of samples to be included in the next batch. We use the CIFAR-10 dataset for image classification to test our Learning Rate Finder."
]
},
{
"cell_type": "markdown",
"id": "692e5e65",
"metadata": {},
"source": [
"```python\n",
"from mxnet.gluon.data.vision import transforms\n",
"\n",
"transform = transforms.Compose([\n",
" # Switches HWC to CHW, and converts to `float32`\n",
" transforms.ToTensor(),\n",
" # Channel-wise, using pre-computed means and stds\n",
" transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],\n",
" std=[0.2023, 0.1994, 0.2010])\n",
"])\n",
"\n",
"dataset = mx.gluon.data.vision.datasets.CIFAR10(train=True).transform_first(transform)\n",
"\n",
"class ContinuousBatchSampler():\n",
" def __init__(self, sampler, batch_size):\n",
" self._sampler = sampler\n",
" self._batch_size = batch_size\n",
"\n",
" def __iter__(self):\n",
" batch = []\n",
" while True:\n",
" for i in self._sampler:\n",
" batch.append(i)\n",
" if len(batch) == self._batch_size:\n",
" yield batch\n",
" batch = []\n",
"\n",
"sampler = mx.gluon.data.RandomSampler(len(dataset))\n",
"batch_sampler = ContinuousBatchSampler(sampler, batch_size=128)\n",
"data_loader = mx.gluon.data.DataLoader(dataset, batch_sampler=batch_sampler)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "1ced7167",
"metadata": {},
"source": [
"## Implementation\n",
"\n",
"With preparation complete, we're ready to write our Learning Rate Finder that wraps the `Learner` we defined above. We implement a `find` method for the procedure, and `plot` for the visualization. Starting with a very low learning rate as defined by `lr_start` we train one iteration at a time and keep multiplying the learning rate by `lr_multiplier`. We analyse the loss and continue until it diverges according to `LRFinderStoppingCriteria` (which is defined later on). You may also notice that we save the parameters and state of the optimizer before the process and restore afterwards. This is so the Learning Rate Finder process doesn't impact the state of the model, and can be used at any point during training."
]
},
{
"cell_type": "markdown",
"id": "f9cae89c",
"metadata": {},
"source": [
"```python\n",
"from matplotlib import pyplot as plt\n",
"\n",
"class LRFinder():\n",
" def __init__(self, learner):\n",
" \"\"\"\n",
" :param learner: able to take single iteration with given learning rate and return loss\n",
" and save and load parameters of the network (Learner)\n",
" \"\"\"\n",
" self.learner = learner\n",
"\n",
" def find(self, lr_start=1e-6, lr_multiplier=1.1, smoothing=0.3):\n",
" \"\"\"\n",
" :param lr_start: learning rate to start search (float)\n",
" :param lr_multiplier: factor the learning rate is multiplied by at each step of search (float)\n",
" :param smoothing: amount of smoothing applied to loss for stopping criteria (float)\n",
" :return: learning rate and loss pairs (list of (float, float) tuples)\n",
" \"\"\"\n",
" # Used to initialize weights; pass data, but don't take step.\n",
" # Would expect for new model with lazy weight initialization\n",
" self.learner.iteration(take_step=False)\n",
" # Used to initialize trainer (if no step has been taken)\n",
" if not self.learner.trainer._kv_initialized:\n",
" self.learner.trainer._init_kvstore()\n",
" # Store params and optimizer state for restore after lr_finder procedure\n",
" # Useful for applying the method partway through training, not just for initialization of lr.\n",
" self.learner.net.save_parameters(\"lr_finder.params\")\n",
" self.learner.trainer.save_states(\"lr_finder.state\")\n",
" lr = lr_start\n",
" self.results = [] # List of (lr, loss) tuples\n",
" stopping_criteria = LRFinderStoppingCriteria(smoothing)\n",
" while True:\n",
" # Run iteration, and block until loss is calculated.\n",
" loss = self.learner.iteration(lr)\n",
" self.results.append((lr, loss))\n",
" if stopping_criteria(loss):\n",
" break\n",
" lr = lr * lr_multiplier\n",
" # Restore params (as finder changed them)\n",
" self.learner.net.load_parameters(\"lr_finder.params\", ctx=self.learner.ctx)\n",
" self.learner.trainer.load_states(\"lr_finder.state\")\n",
" return self.results\n",
"\n",
" def plot(self):\n",
" lrs = [e[0] for e in self.results]\n",
" losses = [e[1] for e in self.results]\n",
" plt.figure(figsize=(6,8))\n",
" plt.scatter(lrs, losses)\n",
" plt.xlabel(\"Learning Rate\")\n",
" plt.ylabel(\"Loss\")\n",
" plt.xscale('log')\n",
" plt.yscale('log')\n",
" axes = plt.gca()\n",
" axes.set_xlim([lrs[0], lrs[-1]])\n",
" y_lower = min(losses) * 0.8\n",
" y_upper = losses[0] * 4\n",
" axes.set_ylim([y_lower, y_upper])\n",
" plt.show()\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "d5ef0caa",
"metadata": {},
"source": [
"You can define the `LRFinderStoppingCriteria` as you wish, but empirical testing suggests using a smoothed average gives a more consistent stopping rule (see `smoothing`). We stop when the smoothed average of the loss exceeds twice the initial loss, assuming there have been a minimum number of iterations (see `min_iter`)."
]
},
{
"cell_type": "markdown",
"id": "b40464d8",
"metadata": {},
"source": [
"```python\n",
"class LRFinderStoppingCriteria():\n",
" def __init__(self, smoothing=0.3, min_iter=20):\n",
" \"\"\"\n",
" :param smoothing: applied to running mean which is used for thresholding (float)\n",
" :param min_iter: minimum number of iterations before early stopping can occur (int)\n",
" \"\"\"\n",
" self.smoothing = smoothing\n",
" self.min_iter = min_iter\n",
" self.first_loss = None\n",
" self.running_mean = None\n",
" self.counter = 0\n",
"\n",
" def __call__(self, loss):\n",
" \"\"\"\n",
" :param loss: from single iteration (float)\n",
" :return: indicator to stop (boolean)\n",
" \"\"\"\n",
" self.counter += 1\n",
" if self.first_loss is None:\n",
" self.first_loss = loss\n",
" if self.running_mean is None:\n",
" self.running_mean = loss\n",
" else:\n",
" self.running_mean = ((1 - self.smoothing) * loss) + (self.smoothing * self.running_mean)\n",
" return (self.running_mean > self.first_loss * 2) and (self.counter >= self.min_iter)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "8c1ad465",
"metadata": {},
"source": [
"## Usage\n",
"\n",
"Using a Pre-activation ResNet-18 from the Gluon model zoo, we instantiate our Learner and fire up our Learning Rate Finder!"
]
},
{
"cell_type": "markdown",
"id": "d78fc738",
"metadata": {},
"source": [
"```python\n",
"ctx = mx.gpu() if mx.context.num_gpus() else mx.cpu()\n",
"net = mx.gluon.model_zoo.vision.resnet18_v2(classes=10)\n",
"learner = Learner(net=net, data_loader=data_loader, ctx=ctx)\n",
"lr_finder = LRFinder(learner)\n",
"lr_finder.find(lr_start=1e-6)\n",
"lr_finder.plot()\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "b4ec7f80",
"metadata": {},
"source": [
"![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_finder/finder_plot.png) <!--notebook-skip-line-->\n",
"\n",
"\n",
"As discussed before, we should select a learning rate where the loss is falling (i.e. from 0.001 to 0.05) but before the loss starts to diverge (i.e. 0.1). We prefer higher learning rates where possible, so we select an initial learning rate of 0.05. Just as a test, we will run 500 epochs using this learning rate and evaluate the loss on the final batch. As we're working with a single batch of 128 samples, the variance of the loss estimates will be reasonably high, but it will give us a general idea. We save the initialized parameters for a later comparison with other learning rates."
]
},
{
"cell_type": "markdown",
"id": "d5e1f07c",
"metadata": {},
"source": [
"```python\n",
"learner.net.save_parameters(\"net.params\")\n",
"lr = 0.05\n",
"\n",
"for iter_idx in range(300):\n",
" learner.iteration(lr=lr)\n",
" if ((iter_idx % 100) == 0):\n",
" print(\"Iteration: {}, Loss: {:.5g}\".format(iter_idx, learner.iteration_loss))\n",
"print(\"Final Loss: {:.5g}\".format(learner.iteration_loss))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "f34558e4",
"metadata": {},
"source": [
"Iteration: 0, Loss: 2.785 <!--notebook-skip-line-->\n",
"\n",
"Iteration: 100, Loss: 1.6653 <!--notebook-skip-line-->\n",
"\n",
"Iteration: 200, Loss: 1.4891 <!--notebook-skip-line-->\n",
"\n",
"\n",
"Final Loss: 1.1812 <!--notebook-skip-line-->\n",
"\n",
"\n",
"We see a sizable drop in the loss from approx. 2.7 to 1.2.\n",
"\n",
"And now we have a baseline, let's see what happens when we train with a learning rate that's higher than advisable at 0.5."
]
},
{
"cell_type": "markdown",
"id": "7914a6a5",
"metadata": {},
"source": [
"```python\n",
"net = mx.gluon.model_zoo.vision.resnet18_v2(classes=10)\n",
"learner = Learner(net=net, data_loader=data_loader, ctx=ctx)\n",
"learner.net.load_parameters(\"net.params\", ctx=ctx)\n",
"lr = 0.5\n",
"\n",
"for iter_idx in range(300):\n",
" learner.iteration(lr=lr)\n",
" if ((iter_idx % 100) == 0):\n",
" print(\"Iteration: {}, Loss: {:.5g}\".format(iter_idx, learner.iteration_loss))\n",
"print(\"Final Loss: {:.5g}\".format(learner.iteration_loss))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "c98469f0",
"metadata": {},
"source": [
"Iteration: 0, Loss: 2.6469 <!--notebook-skip-line-->\n",
"\n",
"Iteration: 100, Loss: 1.9666 <!--notebook-skip-line-->\n",
"\n",
"Iteration: 200, Loss: 1.6919 <!--notebook-skip-line-->\n",
"\n",
"\n",
"Final Loss: 1.366 <!--notebook-skip-line-->\n",
"\n",
"\n",
"We still observe a fall in the loss but aren't able to reach as low as before.\n",
"\n",
"And lastly, we see how the model trains with a more conservative learning rate of 0.005."
]
},
{
"cell_type": "markdown",
"id": "c38a5be5",
"metadata": {},
"source": [
"```python\n",
"net = mx.gluon.model_zoo.vision.resnet18_v2(classes=10)\n",
"learner = Learner(net=net, data_loader=data_loader, ctx=ctx)\n",
"learner.net.load_parameters(\"net.params\", ctx=ctx)\n",
"lr = 0.005\n",
"\n",
"for iter_idx in range(300):\n",
" learner.iteration(lr=lr)\n",
" if ((iter_idx % 100) == 0):\n",
" print(\"Iteration: {}, Loss: {:.5g}\".format(iter_idx, learner.iteration_loss))\n",
"print(\"Final Loss: {:.5g}\".format(learner.iteration_loss))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "e04b1fa4",
"metadata": {},
"source": [
"Iteration: 0, Loss: 2.605 <!--notebook-skip-line-->\n",
"\n",
"Iteration: 100, Loss: 1.8621 <!--notebook-skip-line-->\n",
"\n",
"Iteration: 200, Loss: 1.6316 <!--notebook-skip-line-->\n",
"\n",
"\n",
"Final Loss: 1.2919 <!--notebook-skip-line-->\n",
"\n",
"\n",
"Although we get quite similar results to when we set the learning rate at 0.05 (because we're still in the region of falling loss on the Learning Rate Finder plot), we can still optimize our network faster using a slightly higher rate.\n",
"\n",
"## Wrap Up\n",
"\n",
"Give Learning Rate Finder a try on your current projects, and experiment with the different learning rate schedules found in the [basic learning rate tutorial](/api/python/docs/tutorials/packages/gluon/training/learning_rates/learning_rate_schedules.html) and the [advanced learning rate tutorial](/api/python/docs/tutorials/packages/gluon/training/learning_rates/learning_rate_schedules_advanced.html).\n",
"\n",
"<!-- INSERT SOURCE DOWNLOAD BUTTONS -->"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}