{
"cells": [
{
"cell_type": "markdown",
"id": "25910343",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"\n",
"# Learning Rate Schedules\n",
"\n",
"Setting the learning rate for stochastic gradient descent (SGD) is crucially important when training neural networks because it controls both the speed of convergence and the ultimate performance of the network. One of the simplest learning rate strategies is to have a fixed learning rate throughout the training process. Choosing a small learning rate allows the optimizer find good solutions, but this comes at the expense of limiting the initial speed of convergence. Changing the learning rate over time can overcome this tradeoff.\n",
"\n",
"Schedules define how the learning rate changes over time and are typically specified for each epoch or iteration (i.e. batch) of training. Schedules differ from adaptive methods (such as AdaDelta and Adam) because they:\n",
"\n",
"* change the global learning rate for the optimizer, rather than parameter-wise learning rates\n",
"* don't take feedback from the training process and are specified beforehand\n",
"\n",
"In this tutorial, we visualize the schedules defined in `mx.lr_scheduler`, show how to implement custom schedules and see an example of using a schedule while training models. Since schedules are passed to `mx.optimizer.Optimizer` classes, these methods work with both Module and Gluon APIs."
]
},
{
"cell_type": "markdown",
"id": "bebae10e",
"metadata": {},
"source": [
"```python\n",
"from __future__ import print_function\n",
"import math\n",
"import matplotlib.pyplot as plt\n",
"import mxnet as mx\n",
"from mxnet.gluon import nn\n",
"from mxnet.gluon.data.vision import transforms\n",
"import numpy as np\n",
"%matplotlib inline\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "167bc156",
"metadata": {},
"source": [
"```python\n",
"def plot_schedule(schedule_fn, iterations=1500):\n",
" # Iteration count starting at 1\n",
" iterations = [i+1 for i in range(iterations)]\n",
" lrs = [schedule_fn(i) for i in iterations]\n",
" plt.scatter(iterations, lrs)\n",
" plt.xlabel(\"Iteration\")\n",
" plt.ylabel(\"Learning Rate\")\n",
" plt.show()\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "97215941",
"metadata": {},
"source": [
"## Schedules\n",
"\n",
"In this section, we take a look at the schedules in `mx.lr_scheduler`. All of these schedules define the learning rate for a given iteration, and it is expected that iterations start at 1 rather than 0. So to find the learning rate for the 100th iteration, you can call `schedule(100)`.\n",
"\n",
"### Stepwise Decay Schedule\n",
"\n",
"One of the most commonly used learning rate schedules is called stepwise decay, where the learning rate is reduced by a factor at certain intervals. MXNet implements a `FactorScheduler` for equally spaced intervals, and `MultiFactorScheduler` for greater control. We start with an example of halving the learning rate every 250 iterations. More precisely, the learning rate will be multiplied by `factor` _after_ the `step` index and multiples thereafter. So in the example below the learning rate of the 250th iteration will be 1 and the 251st iteration will be 0.5."
]
},
{
"cell_type": "markdown",
"id": "0d28b7e5",
"metadata": {},
"source": [
"```python\n",
"schedule = mx.lr_scheduler.FactorScheduler(step=250, factor=0.5)\n",
"schedule.base_lr = 1\n",
"plot_schedule(schedule)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "81bd94d6",
"metadata": {},
"source": [
"![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_schedules/factor.png) <!--notebook-skip-line-->\n",
"\n",
"\n",
"Note: the `base_lr` is used to determine the initial learning rate. It takes a default value of 0.01 since we inherit from `mx.lr_scheduler.LRScheduler`, but it can be set as a property of the schedule. We will see later in this tutorial that `base_lr` is set automatically when providing the `lr_schedule` to `Optimizer`. Also be aware that the schedules in `mx.lr_scheduler` have state (i.e. counters, etc) so calling the schedule out of order may give unexpected results.\n",
"\n",
"We can define non-uniform intervals with `MultiFactorScheduler` and in the example below we halve the learning rate _after_ the 250th, 750th (i.e. a step length of 500 iterations) and 900th (a step length of 150 iterations). As before, the learning rate of the 250th iteration will be 1 and the 251th iteration will be 0.5."
]
},
{
"cell_type": "markdown",
"id": "ef2f1804",
"metadata": {},
"source": [
"```python\n",
"schedule = mx.lr_scheduler.MultiFactorScheduler(step=[250, 750, 900], factor=0.5)\n",
"schedule.base_lr = 1\n",
"plot_schedule(schedule)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "65d90197",
"metadata": {},
"source": [
"![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_schedules/multifactor.png) <!--notebook-skip-line-->\n",
"\n",
"\n",
"### Polynomial Schedule\n",
"\n",
"Stepwise schedules and the discontinuities they introduce may sometimes lead to instability in the optimization, so in some cases smoother schedules are preferred. `PolyScheduler` gives a smooth decay using a polynomial function and reaches a learning rate of 0 after `max_update` iterations. In the example below, we have a quadratic function (`pwr=2`) that falls from 0.998 at iteration 1 to 0 at iteration 1000. After this the learning rate stays at 0, so nothing will be learnt from `max_update` iterations onwards."
]
},
{
"cell_type": "markdown",
"id": "41852fcb",
"metadata": {},
"source": [
"```python\n",
"schedule = mx.lr_scheduler.PolyScheduler(max_update=1000, base_lr=1, pwr=2)\n",
"plot_schedule(schedule)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "ccc5a065",
"metadata": {},
"source": [
"![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_schedules/polynomial.png) <!--notebook-skip-line-->\n",
"\n",
"\n",
"Note: unlike `FactorScheduler`, the `base_lr` is set as an argument when instantiating the schedule.\n",
"\n",
"And we don't evaluate at `iteration=0` (to get `base_lr`) since we are working with schedules starting at `iteration=1`.\n",
"\n",
"### Custom Schedules\n",
"\n",
"You can implement your own custom schedule with a function or callable class, that takes an integer denoting the iteration index (starting at 1) and returns a float representing the learning rate to be used for that iteration. We implement the Cosine Annealing Schedule in the example below as a callable class (see `__call__` method)."
]
},
{
"cell_type": "markdown",
"id": "09e4b13a",
"metadata": {},
"source": [
"```python\n",
"class CosineAnnealingSchedule():\n",
" def __init__(self, min_lr, max_lr, cycle_length):\n",
" self.min_lr = min_lr\n",
" self.max_lr = max_lr\n",
" self.cycle_length = cycle_length\n",
" \n",
" def __call__(self, iteration):\n",
" if iteration <= self.cycle_length:\n",
" unit_cycle = (1 + math.cos(iteration * math.pi / self.cycle_length)) / 2\n",
" adjusted_cycle = (unit_cycle * (self.max_lr - self.min_lr)) + self.min_lr\n",
" return adjusted_cycle\n",
" else:\n",
" return self.min_lr\n",
"\n",
"\n",
"schedule = CosineAnnealingSchedule(min_lr=0, max_lr=1, cycle_length=1000)\n",
"plot_schedule(schedule)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "3ebbe83b",
"metadata": {},
"source": [
"![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_schedules/cosine.png) <!--notebook-skip-line-->\n",
"\n",
"\n",
"## Using Schedules\n",
"\n",
"While training a simple handwritten digit classifier on the MNIST dataset, we take a look at how to use a learning rate schedule during training. Our demonstration model is a basic convolutional neural network. We start by preparing our `DataLoader` and defining the network. \n",
"\n",
"As discussed above, the schedule should return a learning rate given an (1-based) iteration index."
]
},
{
"cell_type": "markdown",
"id": "12f29d7e",
"metadata": {},
"source": [
"```python\n",
"# Use GPU if one exists, else use CPU\n",
"ctx = mx.gpu() if mx.context.num_gpus() else mx.cpu()\n",
"\n",
"# MNIST images are 28x28. Total pixels in input layer is 28x28 = 784\n",
"num_inputs = 784\n",
"# Clasify the images into one of the 10 digits\n",
"num_outputs = 10\n",
"# 64 images in a batch\n",
"batch_size = 64\n",
"\n",
"# Load the training data\n",
"train_dataset = mx.gluon.data.vision.MNIST(train=True).transform_first(transforms.ToTensor())\n",
"train_dataloader = mx.gluon.data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=5)\n",
"\n",
"# Build a simple convolutional network\n",
"def build_cnn():\n",
" net = nn.HybridSequential()\n",
" with net.name_scope():\n",
" # First convolution\n",
" net.add(nn.Conv2D(channels=10, kernel_size=5, activation='relu'))\n",
" net.add(nn.MaxPool2D(pool_size=2, strides=2))\n",
" # Second convolution\n",
" net.add(nn.Conv2D(channels=20, kernel_size=5, activation='relu'))\n",
" net.add(nn.MaxPool2D(pool_size=2, strides=2))\n",
" # Flatten the output before the fully connected layers\n",
" net.add(nn.Flatten())\n",
" # First fully connected layers with 512 neurons\n",
" net.add(nn.Dense(512, activation=\"relu\"))\n",
" # Second fully connected layer with as many neurons as the number of classes\n",
" net.add(nn.Dense(num_outputs))\n",
" return net\n",
" \n",
"net = build_cnn()\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "69fa60e1",
"metadata": {},
"source": [
"We then initialize our network (technically deferred until we pass the first batch) and define the loss."
]
},
{
"cell_type": "markdown",
"id": "9a041454",
"metadata": {},
"source": [
"```python\n",
"# Initialize the parameters with Xavier initializer\n",
"net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)\n",
"# Use cross entropy loss\n",
"softmax_cross_entropy = mx.gluon.loss.SoftmaxCrossEntropyLoss()\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "b4373c73",
"metadata": {},
"source": [
"We're now ready to create our schedule, and in this example we opt for a stepwise decay schedule using `MultiFactorScheduler`. Since we're only training a demonstration model for a limited number of epochs (10 in total) we will exaggerate the schedule and drop the learning rate by 90% after the 4th, 7th and 9th epochs. We call these steps, and the drop occurs _after_ the step index. Schedules are defined for iterations (i.e. training batches), so we must represent our steps in iterations too."
]
},
{
"cell_type": "markdown",
"id": "db522c29",
"metadata": {},
"source": [
"```python\n",
"steps_epochs = [4, 7, 9]\n",
"# assuming we keep partial batches, see `last_batch` parameter of DataLoader\n",
"iterations_per_epoch = math.ceil(len(train_dataset) / batch_size)\n",
"# iterations just before starts of epochs (iterations are 1-indexed)\n",
"steps_iterations = [s*iterations_per_epoch for s in steps_epochs]\n",
"print(\"Learning rate drops after iterations: {}\".format(steps_iterations))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "04705941",
"metadata": {},
"source": [
"```\n",
"Learning rate drops after iterations: [3752, 6566, 8442]\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "dca27df8",
"metadata": {},
"source": [
"```python\n",
"schedule = mx.lr_scheduler.MultiFactorScheduler(step=steps_iterations, factor=0.1)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "4b9bf02d",
"metadata": {},
"source": [
"**We create our `Optimizer` and pass the schedule via the `lr_scheduler` parameter.** In this example we're using Stochastic Gradient Descent."
]
},
{
"cell_type": "markdown",
"id": "5c89e08b",
"metadata": {},
"source": [
"```python\n",
"sgd_optimizer = mx.optimizer.SGD(learning_rate=0.03, lr_scheduler=schedule)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "073e9450",
"metadata": {},
"source": [
"And we use this optimizer (with schedule) in our `Trainer` and train for 10 epochs. Alternatively, we could have set the `optimizer` to the string `sgd`, and pass a dictionary of the optimizer parameters directly to the trainer using `optimizer_params`."
]
},
{
"cell_type": "markdown",
"id": "7afb9f65",
"metadata": {},
"source": [
"```python\n",
"trainer = mx.gluon.Trainer(params=net.collect_params(), optimizer=sgd_optimizer)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "43007cc0",
"metadata": {},
"source": [
"```python\n",
"num_epochs = 10\n",
"# epoch and batch counts starting at 1\n",
"for epoch in range(1, num_epochs+1):\n",
" # Iterate through the images and labels in the training data\n",
" for batch_num, (data, label) in enumerate(train_dataloader, start=1):\n",
" # get the images and labels\n",
" data = data.as_in_context(ctx)\n",
" label = label.as_in_context(ctx)\n",
" # Ask autograd to record the forward pass\n",
" with mx.autograd.record():\n",
" # Run the forward pass\n",
" output = net(data)\n",
" # Compute the loss\n",
" loss = softmax_cross_entropy(output, label)\n",
" # Compute gradients\n",
" loss.backward()\n",
" # Update parameters\n",
" trainer.step(data.shape[0])\n",
"\n",
" # Show loss and learning rate after first iteration of epoch\n",
" if batch_num == 1:\n",
" curr_loss = mx.nd.mean(loss).asscalar()\n",
" curr_lr = trainer.learning_rate\n",
" print(\"Epoch: %d; Batch %d; Loss %f; LR %f\" % (epoch, batch_num, curr_loss, curr_lr))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "43ca8677",
"metadata": {},
"source": [
"Epoch: 1; Batch 1; Loss 2.304071; LR 0.030000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 2; Batch 1; Loss 0.059640; LR 0.030000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 3; Batch 1; Loss 0.072601; LR 0.030000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 4; Batch 1; Loss 0.042228; LR 0.030000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 5; Batch 1; Loss 0.025745; LR 0.003000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 6; Batch 1; Loss 0.027391; LR 0.003000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 7; Batch 1; Loss 0.048237; LR 0.003000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 8; Batch 1; Loss 0.024213; LR 0.000300 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 9; Batch 1; Loss 0.008892; LR 0.000300 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 10; Batch 1; Loss 0.006875; LR 0.000030 <!--notebook-skip-line-->\n",
"\n",
"\n",
"We see that the learning rate starts at 0.03, and falls to 0.00003 by the end of training as per the schedule we defined.\n",
"\n",
"### Manually setting the learning rate: Gluon API only\n",
"\n",
"When using the method above you don't need to manually keep track of iteration count and set the learning rate, so this is the recommended approach for most cases. Sometimes you might want more fine-grained control over setting the learning rate though, so Gluon's `Trainer` provides the `set_learning_rate` method for this.\n",
"\n",
"We replicate the example above, but now keep track of the `iteration_idx`, call the schedule and set the learning rate appropriately using `set_learning_rate`. We also use `schedule.base_lr` to set the initial learning rate for the schedule since we are calling the schedule directly and not using it as part of the `Optimizer`."
]
},
{
"cell_type": "markdown",
"id": "95e91672",
"metadata": {},
"source": [
"```python\n",
"net = build_cnn()\n",
"net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)\n",
"\n",
"schedule = mx.lr_scheduler.MultiFactorScheduler(step=steps_iterations, factor=0.1)\n",
"schedule.base_lr = 0.03\n",
"sgd_optimizer = mx.optimizer.SGD()\n",
"trainer = mx.gluon.Trainer(params=net.collect_params(), optimizer=sgd_optimizer)\n",
"\n",
"iteration_idx = 1\n",
"num_epochs = 10\n",
"# epoch and batch counts starting at 1\n",
"for epoch in range(1, num_epochs + 1):\n",
" # Iterate through the images and labels in the training data\n",
" for batch_num, (data, label) in enumerate(train_dataloader, start=1):\n",
" # get the images and labels\n",
" data = data.as_in_context(ctx)\n",
" label = label.as_in_context(ctx)\n",
" # Ask autograd to record the forward pass\n",
" with mx.autograd.record():\n",
" # Run the forward pass\n",
" output = net(data)\n",
" # Compute the loss\n",
" loss = softmax_cross_entropy(output, label)\n",
" # Compute gradients\n",
" loss.backward()\n",
" # Update the learning rate\n",
" lr = schedule(iteration_idx)\n",
" trainer.set_learning_rate(lr)\n",
" # Update parameters\n",
" trainer.step(data.shape[0])\n",
" # Show loss and learning rate after first iteration of epoch\n",
" if batch_num == 1:\n",
" curr_loss = mx.nd.mean(loss).asscalar()\n",
" curr_lr = trainer.learning_rate\n",
" print(\"Epoch: %d; Batch %d; Loss %f; LR %f\" % (epoch, batch_num, curr_loss, curr_lr))\n",
" iteration_idx += 1\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "2c5cc46a",
"metadata": {},
"source": [
"Epoch: 1; Batch 1; Loss 2.334119; LR 0.030000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 2; Batch 1; Loss 0.178930; LR 0.030000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 3; Batch 1; Loss 0.142640; LR 0.030000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 4; Batch 1; Loss 0.041116; LR 0.030000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 5; Batch 1; Loss 0.051049; LR 0.003000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 6; Batch 1; Loss 0.027170; LR 0.003000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 7; Batch 1; Loss 0.083776; LR 0.003000 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 8; Batch 1; Loss 0.082553; LR 0.000300 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 9; Batch 1; Loss 0.027984; LR 0.000300 <!--notebook-skip-line-->\n",
"\n",
"Epoch: 10; Batch 1; Loss 0.030896; LR 0.000030 <!--notebook-skip-line-->\n",
"\n",
"\n",
"Once again, we see the learning rate start at 0.03, and fall to 0.00003 by the end of training as per the schedule we defined.\n",
"\n",
"## Advanced Schedules\n",
"\n",
"We have a related tutorial on Advanced Learning Rate Schedules that shows reference implementations of schedules that give state-of-the-art results. We look at cyclical schedules applied to a variety of cycle shapes, and many other techniques such as warm-up and cool-down.\n",
"\n",
"<!-- INSERT SOURCE DOWNLOAD BUTTONS -->"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}