{
"cells": [
{
"cell_type": "markdown",
"id": "13dc51f4",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"# Trainer\n",
"\n",
"Training a neural network model consists of iteratively performing three simple steps.\n",
"\n",
"The first step is the forward step which computes the loss. In MXNet Gluon, this first step is achieved by doing a forward pass by calling `net.forward(X)` or simply `net(X)` and then calling the loss function with the result of the forward pass and the labels. For example `l = loss_fn(net(X), y)`.\n",
"\n",
"The second step is the backward step which computes the gradient of the loss with respect to the parameters. In Gluon, this step is achieved by doing the first step in an [autograd.record()](../../../../api/autograd/index.rst#mxnet.autograd.record) scope to record the computations needed to calculate the loss, and then calling `l.backward()` to compute the gradient of the loss with respect to the parameters.\n",
"\n",
"The final step is to update the neural network model parameters using an optimization algorithm. In Gluon, this step is performed by the [gluon.Trainer](../../../../api/gluon/trainer.rst#mxnet.gluon.Trainer) and is the subject of this guide. When creating a Gluon `Trainer` you must provide a collection of parameters that need to be learnt. You also provide an `Optimizer` that will be used to update the parameters every training iteration when `trainer.step` is called.\n",
"\n",
"## Basic Usage\n",
"\n",
"### Network and Trainer\n",
"\n",
"To illustrate how to use the Gluon `Trainer` we will create a simple perceptron model and create a `Trainer ` instance using the perceptron model parameters and a simple optimizer - `sgd` with learning rate as 1."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2aef773c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[04:51:16] /work/mxnet/src/storage/storage.cc:202: Using Pooled (Naive) StorageManager for CPU\n"
]
}
],
"source": [
"from mxnet import np, autograd, optimizer, gluon\n",
"\n",
"net = gluon.nn.Dense(1)\n",
"net.initialize()\n",
"\n",
"trainer = gluon.Trainer(net.collect_params(),\n",
" optimizer='sgd', optimizer_params={'learning_rate':1})\n"
]
},
{
"cell_type": "markdown",
"id": "497f452c",
"metadata": {},
"source": [
"### Forward and Backward Pass\n",
"\n",
"Before we can use the `trainer` to update model parameters, we must first run the forward and backward passes. Here we implement a function to compute the first two steps (forward step and backward step) of training the perceptron on a random dataset."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7f65436c",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 8\n",
"X = np.random.uniform(size=(batch_size, 4))\n",
"y = np.random.uniform(size=(batch_size,))\n",
"\n",
"loss = gluon.loss.L2Loss()\n",
"\n",
"def forward_backward():\n",
" with autograd.record():\n",
" l = loss(net(X), y)\n",
" l.backward()\n",
"\n",
"forward_backward()"
]
},
{
"cell_type": "markdown",
"id": "c8887cec",
"metadata": {},
"source": [
"**Warning**: It is extremely important that the gradients of the loss function with respect to your model parameters are computed before running `trainer step`. A common way to introduce bugs to your model training code is to omit the `loss.backward()`before the update step.\n",
"\n",
"\n",
"\n",
"Before updating, let's check the current network parameters."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7625f6a5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.00427801 0.01983307 0.05400988 -0.03179503]]\n"
]
}
],
"source": [
"curr_weight = net.weight.data().copy()\n",
"print(curr_weight)"
]
},
{
"cell_type": "markdown",
"id": "4ad139b2",
"metadata": {},
"source": [
"### `Trainer` step\n",
"\n",
"Now we will call the `step` method to perform one update. We provide the `batch_size` as an argument to normalize the size of the gradients and make it independent of the batch size. Otherwise we'd get larger gradients with larger batch sizes. We can see the network parameters have now changed."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5acb4e2f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.14192605 0.27255496 0.27619216 0.12496824]]\n"
]
}
],
"source": [
"trainer.step(batch_size)\n",
"print(net.weight.data())"
]
},
{
"cell_type": "markdown",
"id": "f3b5b3fd",
"metadata": {},
"source": [
"Since we used plain SGD, the update rule is $w = w - \\eta/b \\nabla \\ell$, where $b$ is the batch size and $\\nabla\\ell$ is the gradient of the loss function with respect to the weights and $\\eta$ is the learning rate.\n",
"\n",
"We can verify it by running the following code snippet which is explicitly performing the SGD update."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d4298d9b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.14192605 0.27255496 0.27619216 0.12496824]]\n"
]
}
],
"source": [
"print(curr_weight - net.weight.grad() * 1 / batch_size)"
]
},
{
"cell_type": "markdown",
"id": "8becb733",
"metadata": {},
"source": [
"## Advanced Usage\n",
"\n",
"### Using Optimizer Instance\n",
"\n",
"In the previous example, we use the string argument `sgd` to select the optimization method, and `optimizer_params` to specify the optimization method arguments.\n",
"\n",
"All pre-defined optimization methods can be passed in this way and the complete list of implemented optimizers is provided in the [mxnet.optimizer](../../../../api/optimizer/index.rst) module.\n",
"\n",
"However we can also pass an optimizer instance directly to the `Trainer` constructor.\n",
"\n",
"For example:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f427b0e1",
"metadata": {},
"outputs": [],
"source": [
"optim = optimizer.Adam(learning_rate = 1)\n",
"trainer = gluon.Trainer(net.collect_params(), optim)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "902f96f9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-0.8580791 , -0.72745013, -0.7238127 , -0.87503594]])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"forward_backward()\n",
"trainer.step(batch_size)\n",
"net.weight.data()"
]
},
{
"cell_type": "markdown",
"id": "7874c8cf",
"metadata": {},
"source": [
"For reference and implementation details about each optimizer, please refer to the [guide](../../optimizer/index.ipynb) and [API doc](../../../../api/optimizer/index.rst) for the `optimizer` module.\n",
"\n",
"### KVStore Options\n",
"\n",
"The `Trainer` constructor also accepts the following keyword arguments for :\n",
"\n",
"- `kvstore` – how key value store should be created for multi-gpu and distributed training. Check out [mxnet.kvstore.KVStore](../../../../api/kvstore/index.rst) for more information. String options are any of the following ['local', 'device', 'dist_device_sync', 'dist_device_async'].\n",
"- `compression_params` – Specifies type of gradient compression and additional arguments depending on the type of compression being used. See [mxnet.KVStore.set_gradient_compression_method](../../../../api/kvstore/generated/mxnet.kvstore.KVStore.rst) for more details on gradient compression.\n",
"- `update_on_kvstore` – Whether to perform parameter updates on KVStore. If None, then the `Trainer` instance will choose the more suitable option depending on the type of KVStore.\n",
"\n",
"### Changing the Learning Rate\n",
"\n",
"We set the initial learning rate when creating a trainer by passing the learning rate as an `optimizer_param`. However, sometimes we may need to change the learning rate during training, for example when doing an explicit learning rate warmup schedule. The trainer instance provides an easy way to achieve this.\n",
"\n",
"The current training rate can be accessed through the `learning_rate` attribute."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dfdd2858",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.learning_rate"
]
},
{
"cell_type": "markdown",
"id": "f5ab2cca",
"metadata": {},
"source": [
"We can change it through the `set_learning_rate` method."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "64118b86",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.1"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.set_learning_rate(0.1)\n",
"trainer.learning_rate"
]
},
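{
"cell_type": "markdown",
"id": "9a1c0ffe",
"metadata": {},
"source": [
"For instance, a simple linear warm-up over the first few epochs could be written with `set_learning_rate` (a minimal sketch reusing the `trainer` and `forward_backward` defined above; the target learning rate and the number of warm-up epochs are made up for illustration):\n",
"\n",
"```python\n",
"target_lr = 0.1\n",
"warmup_epochs = 5\n",
"\n",
"for epoch in range(warmup_epochs):\n",
"    # ramp the learning rate linearly towards the target value\n",
"    trainer.set_learning_rate(target_lr * (epoch + 1) / warmup_epochs)\n",
"    forward_backward()\n",
"    trainer.step(batch_size)\n",
"```"
]
},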
{
"cell_type": "markdown",
"id": "686378b7",
"metadata": {},
"source": [
"In addition, there are multiple pre-defined learning rate scheduling methods that are already implemented in the [mxnet.lr_scheduler](../../../../api/lr_scheduler/index.rst) module. The learning rate schedulers can be incorporated into your trainer by passing them in as an `optimizer_param` entry. Please refer to the [LR scheduler guide](./learning_rates/learning_rate_schedules.ipynb) to learn more.\n",
"\n",
"\n",
"\n",
"## Summary\n",
"\n",
"* The MXNet Gluon `Trainer` API is used to update the parameters of a network with a particular optimization algorithm.\n",
"* After the forward and backward pass, the model update step is done in Gluon using `trainer.step()`.\n",
"* A Gluon `Trainer` can be instantiated by passing in the name of the optimizer to use and the `optimizer_params` for that optimizer or alternatively by passing in an instance of `mxnet.optimizer.Optimizer`.\n",
"* You can change the learning rate for a Gluon `Trainer` by setting the member variable but Gluon also provides a module for learning rate scheduling.\n",
"\n",
"\n",
"\n",
"## Next Steps\n",
"\n",
"While optimization and optimizers play a significant role in deep learning model training, there are still other important components to model training. Here are a few suggestions about where to look next.\n",
"\n",
"* The [Optimizer API](../../../../api/optimizer/index.rst) and [optimizer guide](../../optimizer/index.ipynb) have information about all the different optimizers implemented in MXNet and their update steps. The [Dive into Deep Learning](http://d2l.ai/chapter_optimization/index.html) book also has a chapter dedicated to optimization methods and explains various key optimizers in great detail.\n",
"\n",
"- Take a look at the [guide to parameter initialization](../blocks/init.ipynb) in MXNet to learn about what initialization schemes are already implemented, and how to implement your custom initialization schemes.\n",
"- Also check out this [guide on parameter management](../blocks/parameters.ipynb) to learn about how to manage model parameters in gluon.\n",
"- Make sure to take a look at the [guide to scheduling learning rates](./learning_rates/learning_rate_schedules.ipynb) to learn how to create learning rate schedules to make your training converge faster.\n",
"- Finally take a look at the [KVStore API](../../../../api/kvstore/index.rst) to learn how parameter values are synchronized over multiple devices."
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}