{
"cells": [
{
"cell_type": "markdown",
"id": "2399c188",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"# Train the neural network\n",
"\n",
"In this section, we will discuss how to train the previously defined network with data. We first import the libraries. The new ones are `mxnet.init` for more weight initialization methods, the `datasets` and `transforms` to load and transform computer vision datasets, `matplotlib` for drawing, and `time` for benchmarking."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2788ab44",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "1"
}
},
"outputs": [],
"source": [
"# Uncomment the following line if matplotlib is not installed.\n",
"# !pip install matplotlib\n",
"\n",
"from mxnet import nd, gluon, init, autograd\n",
"from mxnet.gluon import nn\n",
"from mxnet.gluon.data.vision import datasets, transforms\n",
"from IPython import display\n",
"import matplotlib.pyplot as plt\n",
"import time"
]
},
{
"cell_type": "markdown",
"id": "43df4f99",
"metadata": {},
"source": [
"## Get data\n",
"\n",
"The handwritten digit MNIST dataset is one of the most commonly used datasets in deep learning. But it is too simple to get a 99% accuracy. Here we use a similar but slightly more complicated dataset called FashionMNIST. The goal is no longer to classify numbers, but clothing types instead.\n",
"\n",
"The dataset can be automatically downloaded through Gluon's `data.vision.datasets` module. The following code downloads the training dataset and shows the first example."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7eb83989",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "2"
}
},
"outputs": [],
"source": [
"mnist_train = datasets.FashionMNIST(train=True)\n",
"X, y = mnist_train[0]\n",
"('X shape:', X.shape, 'X dtype:', X.dtype, 'y:', y)"
]
},
{
"cell_type": "markdown",
"id": "3f3adee5",
"metadata": {},
"source": [
"Each example in this dataset is a $28\\times 28$ size grey image, which is presented as NDArray with the shape format of `(height, width, channel)`. The label is a `numpy` scalar.\n",
"\n",
"Next, we visualize the first six examples."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cbd66525",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "3"
}
},
"outputs": [],
"source": [
"text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',\n",
" 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']\n",
"X, y = mnist_train[0:10]\n",
"# plot images\n",
"display.set_matplotlib_formats('svg')\n",
"_, figs = plt.subplots(1, X.shape[0], figsize=(15, 15))\n",
"for f,x,yi in zip(figs, X,y):\n",
" # 3D->2D by removing the last channel dim\n",
" f.imshow(x.reshape((28,28)).asnumpy())\n",
" ax = f.axes\n",
" ax.set_title(text_labels[int(yi)])\n",
" ax.title.set_fontsize(14)\n",
" ax.get_xaxis().set_visible(False)\n",
" ax.get_yaxis().set_visible(False)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "7a788520",
"metadata": {},
"source": [
"In order to feed data into a Gluon model, we need to convert the images to the `(channel, height, width)` format with a floating point data type. It can be done by `transforms.ToTensor`. In addition, we normalize all pixel values with `transforms.Normalize` with the real mean 0.13 and standard deviation 0.31. We chain these two transforms together and apply it to the first element of the data pair, namely the images."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "86e15930",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "4"
}
},
"outputs": [],
"source": [
"transformer = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(0.13, 0.31)])\n",
"mnist_train = mnist_train.transform_first(transformer)"
]
},
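{
"cell_type": "markdown",
"id": "a3f91c20",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch, not part of the original pipeline), we can read back the first transformed example. If the transform is wired up correctly, the image should now be a `float32` NDArray in `(channel, height, width)` layout with a mean close to zero."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7d204e1",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: indexing the transformed dataset applies the ToTensor + Normalize\n",
"# chain to the image before returning it.\n",
"Xt, yt = mnist_train[0]\n",
"print(Xt.shape, Xt.dtype)    # expected: (1, 28, 28) and float32\n",
"print(Xt.mean().asscalar())  # roughly 0 after normalization"
]
},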
{
"cell_type": "markdown",
"id": "3c633836",
"metadata": {},
"source": [
"`FashionMNIST` is a subclass of `gluon.data.Dataset`, which defines how to get the `i`-th example. In order to use it in training, we need to get a (randomized) batch of examples. It can be easily done by `gluon.data.DataLoader`. Here we use four works to process data in parallel, which is often necessary especially for complex data transforms."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "15b4fd37",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "5"
}
},
"outputs": [],
"source": [
"batch_size = 256\n",
"train_data = gluon.data.DataLoader(\n",
" mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)"
]
},
{
"cell_type": "markdown",
"id": "36d4c47d",
"metadata": {},
"source": [
"The returned `train_data` is an iterable object that yields batches of images and labels pairs."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "be64a86b",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "6"
}
},
"outputs": [],
"source": [
"for data, label in train_data:\n",
" print(data.shape, label.shape)\n",
" break"
]
},
{
"cell_type": "markdown",
"id": "b2f351f5",
"metadata": {},
"source": [
"Finally, we create a validation dataset and data loader."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e264b25e",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "7"
}
},
"outputs": [],
"source": [
"mnist_valid = gluon.data.vision.FashionMNIST(train=False)\n",
"valid_data = gluon.data.DataLoader(\n",
" mnist_valid.transform_first(transformer),\n",
"\tbatch_size=batch_size, num_workers=4)"
]
},
{
"cell_type": "markdown",
"id": "3cb97a5f",
"metadata": {},
"source": [
"## Define the model\n",
"\n",
"We reimplement the same LeNet introduced before. One difference here is that we changed the weight initialization method to `Xavier`, which is a popular choice for deep convolutional neural networks."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6c0b22ea",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "8"
}
},
"outputs": [],
"source": [
"net = nn.Sequential()\n",
"net.add(nn.Conv2D(channels=6, kernel_size=5, activation='relu'),\n",
" nn.MaxPool2D(pool_size=2, strides=2),\n",
" nn.Conv2D(channels=16, kernel_size=3, activation='relu'),\n",
" nn.MaxPool2D(pool_size=2, strides=2),\n",
" nn.Flatten(),\n",
" nn.Dense(120, activation=\"relu\"),\n",
" nn.Dense(84, activation=\"relu\"),\n",
" nn.Dense(10))\n",
"net.initialize(init=init.Xavier())"
]
},
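{
"cell_type": "markdown",
"id": "c1e8d502",
"metadata": {},
"source": [
"Gluon infers parameter shapes lazily from the first forward pass. As an optional sketch, we can push a dummy batch through the network to confirm that the output has one score per clothing class."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d94a66b3",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: a dummy (batch, channel, height, width) batch triggers shape\n",
"# inference; the output should have shape (batch, 10).\n",
"dummy = nd.random.uniform(shape=(4, 1, 28, 28))\n",
"print(net(dummy).shape)"
]
},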
{
"cell_type": "markdown",
"id": "bdaa274d",
"metadata": {},
"source": [
"Besides the neural network, we need to define the loss function and optimization method for training. We will use standard softmax cross entropy loss for classification problems. It first performs softmax on the output to obtain the predicted probability, and then compares the label with the cross entropy."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1553e2fe",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "9"
}
},
"outputs": [],
"source": [
"softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()"
]
},
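{
"cell_type": "markdown",
"id": "e5b7a814",
"metadata": {},
"source": [
"To see how the loss behaves, here is a minimal sketch with made-up scores: when the largest score falls on the correct class the loss is small, otherwise it is large."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f60c29d5",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: both rows strongly predict class 0. Row 0 is labeled 0 (correct,\n",
"# near-zero loss); row 1 is labeled 1 (wrong, large loss).\n",
"toy_output = nd.array([[10., 0.], [10., 0.]])\n",
"toy_label = nd.array([0, 1])\n",
"print(softmax_cross_entropy(toy_output, toy_label))"
]
},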
{
"cell_type": "markdown",
"id": "f39f7384",
"metadata": {},
"source": [
"The optimization method we pick is the standard stochastic gradient descent with constant learning rate of 0.1."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "06330482",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "10"
}
},
"outputs": [],
"source": [
"trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})"
]
},
{
"cell_type": "markdown",
"id": "4d1a7023",
"metadata": {},
"source": [
"The `trainer` is created with all parameters (both weights and gradients) in `net`. Later on, we only need to call the `step` method to update its weights.\n",
"\n",
"## Train\n",
"\n",
"We create an auxiliary function to calculate the model accuracy."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "44470d28",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "11"
}
},
"outputs": [],
"source": [
"def acc(output, label):\n",
" # output: (batch, num_output) float32 ndarray\n",
" # label: (batch, ) int32 ndarray\n",
" return (output.argmax(axis=1) ==\n",
" label.astype('float32')).mean().asscalar()"
]
},
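{
"cell_type": "markdown",
"id": "0a1b2c3d",
"metadata": {},
"source": [
"A toy call (a sketch with made-up predictions) illustrates the function: only one of the two predictions below matches its label, so the accuracy is 0.5."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b2c3d4e",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: argmax of row 0 is class 1 (matches its label), argmax of row 1\n",
"# is class 0 (does not), so the mean accuracy is 0.5.\n",
"print(acc(nd.array([[0.1, 0.9], [0.8, 0.2]]), nd.array([1, 1])))"
]
},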
{
"cell_type": "markdown",
"id": "8b5e39a9",
"metadata": {},
"source": [
"Now we can implement the complete training loop."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ce80a914",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "12"
}
},
"outputs": [],
"source": [
"for epoch in range(10):\n",
" train_loss, train_acc, valid_acc = 0., 0., 0.\n",
" tic = time.time()\n",
" for data, label in train_data:\n",
" # forward + backward\n",
" with autograd.record():\n",
" output = net(data)\n",
" loss = softmax_cross_entropy(output, label)\n",
" loss.backward()\n",
" # update parameters\n",
" trainer.step(batch_size)\n",
" # calculate training metrics\n",
" train_loss += loss.mean().asscalar()\n",
" train_acc += acc(output, label)\n",
" # calculate validation accuracy\n",
" for data, label in valid_data:\n",
" valid_acc += acc(net(data), label)\n",
" print(\"Epoch %d: loss %.3f, train acc %.3f, test acc %.3f, in %.1f sec\" % (\n",
" epoch, train_loss/len(train_data), train_acc/len(train_data),\n",
" valid_acc/len(valid_data), time.time()-tic))"
]
},
{
"cell_type": "markdown",
"id": "aa12e732",
"metadata": {},
"source": [
"## Save the model\n",
"\n",
"Finally, we save the trained parameters onto disk, so that we can use them later."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "1668712c",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "13"
}
},
"outputs": [],
"source": [
"net.save_parameters('net.params')"
]
}
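,
{
"cell_type": "markdown",
"id": "2c3d4e5f",
"metadata": {},
"source": [
"The saved parameters can later be restored with `load_parameters` into a network with the same architecture. A minimal sketch, reusing `net` here for brevity:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d4e5f60",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: restore the weights we just saved. In practice this would be\n",
"# called on a freshly built network with the same architecture.\n",
"net.load_parameters('net.params')"
]
}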
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}