{
"cells": [
{
"cell_type": "markdown",
"id": "896ea900",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"# Step 6: Train a Neural Network\n",
"\n",
"Now that you have seen all the components needed to create a neural network, you are\n",
"ready to put all the pieces together and train a model end to end.\n",
"\n",
"## 1. Data preparation\n",
"\n",
"The typical process for creating and training a model starts with loading and\n",
"preparing the datasets. For this network you will use a [dataset of leaf\n",
"images](https://data.mendeley.com/datasets/hb74ynkjcn/1) that consists of healthy\n",
"and diseased examples of leaves from twelve different plant species. To get this\n",
"dataset, download and extract it with the following code."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f0a533a",
"metadata": {},
"outputs": [],
"source": [
"# Import all the libraries needed for training\n",
"import time\n",
"import os\n",
"import zipfile\n",
"\n",
"import mxnet as mx\n",
"from mxnet import npx, gluon, init, autograd\n",
"from mxnet.gluon import nn\n",
"from mxnet.gluon.data.vision import transforms\n",
"\n",
"import matplotlib.pyplot as plt\n",
"# Classic NumPy, used by the plotting helper (mxnet.np is not imported as np to avoid a name clash)\n",
"import numpy as np\n",
"\n",
"from prepare_dataset import process_dataset  # utility code to rearrange the data\n",
"\n",
"mx.random.seed(42)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3434543b",
"metadata": {},
"outputs": [],
"source": [
"# Download dataset\n",
"url = 'https://md-datasets-cache-zipfiles-prod.s3.eu-west-1.amazonaws.com/hb74ynkjcn-1.zip'\n",
"zip_file_path = mx.gluon.utils.download(url)\n",
"\n",
"os.makedirs('plants', exist_ok=True)\n",
"\n",
"with zipfile.ZipFile(zip_file_path, 'r') as zf:\n",
"    zf.extractall('plants')\n",
"\n",
"os.remove(zip_file_path)"
]
},
{
"cell_type": "markdown",
"id": "1fde2966",
"metadata": {},
"source": [
"#### Data inspection\n",
"\n",
"If you take a look at the extracted dataset, you will find the following directory structure:"
]
},
{
"cell_type": "markdown",
"id": "f074d4e1",
"metadata": {},
"source": [
"```\n",
"plants\n",
"|-- Alstonia Scholaris (P2)\n",
"|-- Arjun (P1)\n",
"|-- Bael (P4)\n",
"|   |-- diseased\n",
"|       |-- 0016_0001.JPG\n",
"|       |-- .\n",
"|       |-- .\n",
"|       |-- .\n",
"|       |-- 0016_0118.JPG\n",
"|-- .\n",
"|-- .\n",
"|-- .\n",
"|-- Mango (P0)\n",
"    |-- diseased\n",
"    |-- healthy\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "b8364857",
"metadata": {},
"source": [
"Each plant species has its own directory, and each of those directories contains\n",
"subdirectories with examples of diseased leaves, healthy\n",
"leaves, or both. With this dataset you can formulate different classification\n",
"problems. For example, you can create a multi-class classifier that determines\n",
"the species of a plant based on its leaves, or you can create a binary\n",
"classifier that tells you whether the plant is healthy or diseased. Additionally, you can create\n",
"a multi-class, multi-label classifier that tells you both what species a\n",
"plant is and whether it is diseased or healthy. In this example you will stick to\n",
"the simplest classification question, which is whether a plant is healthy or not.\n",
"\n",
"To do this, you need to manipulate the dataset in two ways. First, you need to\n",
"group all images into two classes, healthy and diseased, regardless of the species. Then you\n",
"need to split the data into train, validation, and test sets. We prepared a\n",
"small utility script that does this to get the dataset ready for you.\n",
"Once you run this utility code on the data, the images will already be\n",
"organized in one folder per class, so\n",
"you can use the `ImageFolderDataset` class to load the images from disk into MXNet."
]
},
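{
"cell_type": "markdown",
"id": "5b1f0c21",
"metadata": {},
"source": [
"The `prepare_dataset` script itself is not shown in this tutorial. As a rough illustration of the kind of split it performs, here is a minimal sketch that assumes a hypothetical `split_files` helper and 80/10/10 ratios; the real script may use different ratios and logic.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"def split_files(files, train_frac=0.8, val_frac=0.1, seed=42):\n",
"    # Hypothetical helper (not part of prepare_dataset): shuffle a list of file\n",
"    # names and cut it into train, validation, and test portions.\n",
"    files = list(files)\n",
"    random.Random(seed).shuffle(files)\n",
"    n_train = int(len(files) * train_frac)\n",
"    n_val = int(len(files) * val_frac)\n",
"    train = files[:n_train]\n",
"    val = files[n_train:n_train + n_val]\n",
"    test = files[n_train + n_val:]  # the remainder goes to the test set\n",
"    return train, val, test\n",
"\n",
"# Example with made-up file names for the healthy class\n",
"healthy = [f'healthy_{i:04d}.JPG' for i in range(100)]\n",
"train, val, test = split_files(healthy)\n",
"print(len(train), len(val), len(test))  # 80 10 10\n",
"```\n",
"\n",
"Running such a split separately for each class (healthy and diseased) keeps the class balance roughly the same in every split."
]
},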
{
"cell_type": "code",
"execution_count": null,
"id": "8af6ab9d",
"metadata": {},
"outputs": [],
"source": [
"# Call the utility function to rearrange the images\n",
"process_dataset('plants')"
]
},
{
"cell_type": "markdown",
"id": "ec6eec3a",
"metadata": {},
"source": [
"The rearranged dataset is located in the `datasets` folder, and its new structure\n",
"looks like this:"
]
},
{
"cell_type": "markdown",
"id": "f01413a8",
"metadata": {},
"source": [
"```\n",
"datasets\n",
"|-- test\n",
"|   |-- diseased\n",
"|   |-- healthy\n",
"|-- train\n",
"|-- validation\n",
"    |-- diseased\n",
"    |-- healthy\n",
"        |-- image1.JPG\n",
"        |-- image2.JPG\n",
"        |-- .\n",
"        |-- .\n",
"        |-- .\n",
"        |-- imagen.JPG\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "a3abbe80",
"metadata": {},
"source": [
"Now, you need to create three different Dataset objects from the `train`,\n",
"`validation`, and `test` folders. The `ImageFolderDataset` class takes\n",
"care of inferring the classes from the directory names. If you don't remember\n",
"how `ImageFolderDataset` works, take a look at [Step 5](5-datasets.md)\n",
"of this course for a more detailed description."
]
},
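{
"cell_type": "markdown",
"id": "7e3d2a90",
"metadata": {},
"source": [
"Conceptually, `ImageFolderDataset` treats each subdirectory as one class: the sorted folder names become the `synsets`, and each image gets the index of its folder as its label. The following simplified pure-Python sketch illustrates that idea; `list_image_folder` is a hypothetical helper, not MXNet's implementation, and the real class also handles details such as image decoding.\n",
"\n",
"```python\n",
"import os\n",
"import tempfile\n",
"\n",
"def list_image_folder(root, exts=('.jpg', '.jpeg', '.png')):\n",
"    # Hypothetical sketch: one class per subdirectory, labels follow sorted folder names\n",
"    synsets, items = [], []\n",
"    for folder in sorted(os.listdir(root)):\n",
"        path = os.path.join(root, folder)\n",
"        if not os.path.isdir(path):\n",
"            continue\n",
"        label = len(synsets)\n",
"        synsets.append(folder)\n",
"        for filename in sorted(os.listdir(path)):\n",
"            if filename.lower().endswith(exts):\n",
"                items.append((os.path.join(path, filename), label))\n",
"    return synsets, items\n",
"\n",
"# Build a tiny fake directory tree to try it out\n",
"with tempfile.TemporaryDirectory() as root:\n",
"    for cls in ('healthy', 'diseased'):\n",
"        os.makedirs(os.path.join(root, cls))\n",
"        for i in range(2):\n",
"            open(os.path.join(root, cls, f'{i}.JPG'), 'w').close()\n",
"    synsets, items = list_image_folder(root)\n",
"\n",
"print(synsets)  # ['diseased', 'healthy']\n",
"print([label for _, label in items])  # [0, 0, 1, 1]\n",
"```\n",
"\n",
"Because the folder names are sorted, `diseased` comes before `healthy`, so in this sketch label 0 means diseased and label 1 means healthy."
]
},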
{
"cell_type": "code",
"execution_count": null,
"id": "24a982b8",
"metadata": {},
"outputs": [],
"source": [
"# Use ImageFolderDataset to create a Dataset object from directory structure\n",
"train_dataset = gluon.data.vision.ImageFolderDataset('./datasets/train')\n",
"val_dataset = gluon.data.vision.ImageFolderDataset('./datasets/validation')\n",
"test_dataset = gluon.data.vision.ImageFolderDataset('./datasets/test')"
]
},
{
"cell_type": "markdown",
"id": "f98f8bd8",
"metadata": {},
"source": [
"The result of this operation is a separate Dataset object for each folder.\n",
"These objects hold a collection of images and labels, so they can be\n",
"indexed to get the $i$-th element of the dataset. The $i$-th element is a\n",
"tuple with two objects: the first is the image in array\n",
"form and the second is the corresponding label for that image."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ad108ed",
"metadata": {},
"outputs": [],
"source": [
"sample_idx = 888  # an arbitrary sample index\n",
"sample = train_dataset[sample_idx]\n",
"data = sample[0]\n",
"label = sample[1]\n",
"\n",
"plt.imshow(data.asnumpy())\n",
"print(f\"Data type: {data.dtype}\")\n",
"print(f\"Label: {label}\")\n",
"print(f\"Label description: {train_dataset.synsets[label]}\")\n",
"print(f\"Image shape: {data.shape}\")"
]
},
{
"cell_type": "markdown",
"id": "2c8df173",
"metadata": {},
"source": [
"As you can see from the plot and the printed shape, the image is very large: 4000 x 6000 pixels.\n",
"Usually, you downsize images before passing them to a neural network to reduce the training time.\n",
"It is also customary to make slight random modifications to the training images to improve generalization.\n",
"This process of transforming the data is called data augmentation.\n",
"\n",
"You can augment data in MXNet using `transforms`. For a complete list of all\n",
"the available transformations in MXNet, check out\n",
"[available transforms](../../../api/gluon/data/vision/transforms/index.rst).\n",
"It is very common to apply more than one transform to each image, one after\n",
"another. To this end, you can use the `transforms.Compose` class,\n",
"which chains several transforms into a single pipeline for your images.\n",
"\n",
"You have to compose two different transformation pipelines, one for training\n",
"and the other one for validating and testing, because each pipeline\n",
"serves a different purpose. You need to downsize, convert to tensor, and normalize\n",
"images in all the datasets; however, you typically do not want to randomly flip\n",
"or add color jitter to the validation or test images, because those sets should measure performance on unmodified images."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "45a84b5c",
"metadata": {},
"outputs": [],
"source": [
"# Import transforms and compose a series of transformations for the images\n",
"from mxnet.gluon.data.vision import transforms\n",
"\n",
"jitter_param = 0.05\n",
"\n",
"# Per-channel ImageNet mean and std; Normalize is applied after ToTensor scales pixels to [0, 1]\n",
"mean = [0.485, 0.456, 0.406]\n",
"std = [0.229, 0.224, 0.225]\n",
"\n",
"training_transformer = transforms.Compose([\n",
" transforms.Resize(size=224, keep_ratio=True),\n",
" transforms.CenterCrop(128),\n",
" transforms.RandomFlipLeftRight(),\n",
" transforms.RandomColorJitter(contrast=jitter_param),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean, std)\n",
"])\n",
"\n",
"validation_transformer = transforms.Compose([\n",
" transforms.Resize(size=224, keep_ratio=True),\n",
" transforms.CenterCrop(128),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean, std)\n",
"])"
]
},
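{
"cell_type": "markdown",
"id": "a9c4e512",
"metadata": {},
"source": [
"To make the last two steps of these pipelines concrete, here is a NumPy sketch of what `ToTensor` and `Normalize` compute: `ToTensor` turns an `HWC` image with values in [0, 255] into a `CHW` float32 array in [0, 1], and `Normalize` subtracts the per-channel mean and divides by the per-channel std. The functions below are illustrative re-implementations, not MXNet's API; the inverse, `denormalize`, is what you need when you want to display normalized images.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)\n",
"std = np.array([0.229, 0.224, 0.225], dtype=np.float32)\n",
"\n",
"def to_tensor(img_hwc_uint8):\n",
"    # Illustrative: HWC uint8 in [0, 255] -> CHW float32 in [0, 1]\n",
"    return np.transpose(img_hwc_uint8.astype(np.float32) / 255.0, (2, 0, 1))\n",
"\n",
"def normalize(tensor_chw):\n",
"    # Illustrative: (x - mean) / std per channel, broadcast over height and width\n",
"    return (tensor_chw - mean[:, None, None]) / std[:, None, None]\n",
"\n",
"def denormalize(tensor_chw):\n",
"    # Inverse of normalize: x * std + mean\n",
"    return tensor_chw * std[:, None, None] + mean[:, None, None]\n",
"\n",
"rng = np.random.default_rng(0)\n",
"img = rng.integers(0, 256, size=(128, 128, 3), dtype=np.uint8)\n",
"x = normalize(to_tensor(img))\n",
"print(x.shape)  # (3, 128, 128)\n",
"print(np.allclose(denormalize(x), to_tensor(img), atol=1e-6))  # True\n",
"```"
]
},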
{
"cell_type": "markdown",
"id": "88809d0b",
"metadata": {},
"source": [
"With your augmentations ready, you can create the `DataLoader`s to use them.\n",
"The `gluon.data.DataLoader` class comes in handy for this. You have to pass the dataset with\n",
"the applied transformations (notice the `.transform_first()` method on the datasets,\n",
"which applies the transformation to the image but not to the label)\n",
"to `gluon.data.DataLoader`. Additionally, you need to decide the batch size,\n",
"which is how many images you will pass to the network in each iteration,\n",
"and whether you want to shuffle the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "602bd977",
"metadata": {},
"outputs": [],
"source": [
"# Create data loaders\n",
"batch_size = 4\n",
"train_loader = gluon.data.DataLoader(train_dataset.transform_first(training_transformer),\n",
" batch_size=batch_size,\n",
" shuffle=True,\n",
" try_nopython=True)\n",
"validation_loader = gluon.data.DataLoader(val_dataset.transform_first(validation_transformer),\n",
" batch_size=batch_size,\n",
" try_nopython=True)\n",
"test_loader = gluon.data.DataLoader(test_dataset.transform_first(validation_transformer),\n",
" batch_size=batch_size,\n",
" try_nopython=True)"
]
},
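{
"cell_type": "markdown",
"id": "d2f6b381",
"metadata": {},
"source": [
"The `DataLoader` groups samples into batches, and by default (the `last_batch='keep'` behavior) the final batch can be smaller than `batch_size`. The following NumPy sketch uses a hypothetical `iterate_minibatches` helper that mimics this behavior on a toy array, including optional shuffling, to show how many batches one epoch produces. It is not how `DataLoader` is implemented.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def iterate_minibatches(data, labels, batch_size, shuffle=False, seed=0):\n",
"    # Hypothetical helper: yields (data, labels) batches; the last batch may be smaller\n",
"    indices = np.arange(len(data))\n",
"    if shuffle:\n",
"        np.random.default_rng(seed).shuffle(indices)\n",
"    for start in range(0, len(data), batch_size):\n",
"        batch_idx = indices[start:start + batch_size]\n",
"        yield data[batch_idx], labels[batch_idx]\n",
"\n",
"# Ten fake 3 x 128 x 128 images with alternating labels\n",
"data = np.zeros((10, 3, 128, 128), dtype=np.float32)\n",
"labels = np.arange(10) % 2\n",
"batches = list(iterate_minibatches(data, labels, batch_size=4, shuffle=True))\n",
"print(len(batches))  # 3\n",
"print([b[0].shape[0] for b in batches])  # [4, 4, 2]\n",
"```\n",
"\n",
"In general, with `batch_size = 4` one epoch over `n` images takes `ceil(n / 4)` iterations."
]
},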
{
"cell_type": "markdown",
"id": "f6231c94",
"metadata": {},
"source": [
"Now, you can inspect the transformations that you made to the images. A prepared\n",
"utility function has been provided for this."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21083577",
"metadata": {},
"outputs": [],
"source": [
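"# Function to plot a batch of images\n",
"def show_batch(batch, columns=4, fig_size=(9, 5), pad=1,\n",
"               mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):\n",
"    labels = batch[1].asnumpy()\n",
"    images = batch[0].asnumpy()\n",
"    # Undo Normalize(mean, std): x * std + mean, broadcast over the channel axis\n",
"    images = images * np.array(std).reshape(1, 3, 1, 1) + np.array(mean).reshape(1, 3, 1, 1)\n",
"    images = np.clip(images, 0, 1)  # clip values to the valid display range\n",
"    size = images.shape[0]\n",
"    rows = (size + columns - 1) // columns  # ceiling division\n",
"    fig, axes = plt.subplots(rows, columns, figsize=fig_size, squeeze=False)\n",
"    for ax, img, label in zip(axes.flatten(), images, labels):\n",
"        ax.imshow(np.transpose(img, (1, 2, 0)))  # CHW -> HWC for matplotlib\n",
"        ax.set(title=f\"Label: {label}\")\n",
"    fig.tight_layout(h_pad=pad, w_pad=pad)\n",
"    plt.show()"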
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81a83344",
"metadata": {},
"outputs": [],
"source": [
"# Grab a single batch from the training loader\n",
"a = next(iter(train_loader))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f681d5e7",
"metadata": {},
"outputs": [],
"source": [
"show_batch(a)"
]
},
{
"cell_type": "markdown",
"id": "dfa9fb4a",
"metadata": {},
"source": [
"You can see that the original images have been downsized and cropped to 128 x 128 pixels, and some of them\n",
"may be flipped horizontally or have slightly different contrast. These changes follow the transformations you specified\n",
"in the pipeline. You are now ready to go to the next step: **Create the\n",
"architecture**.\n",
"\n",
"## 2. Create Neural Network\n",
"\n",
"Convolutional neural networks are a great tool to capture the spatial\n",
"relationship of pixel values within images; for this reason they have become the\n",
"standard approach for computer vision. In this example you will create a small convolutional neural\n",
"network using what you learned in [Step 2](2-create-nn.md) of this crash course series.\n",
"First, you can set up two functions that generate the two types of blocks\n",
"you intend to use: the convolution block and the dense block. Then you can create an\n",
"entire network based on these two blocks using a custom class."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4ff0029",
"metadata": {},
"outputs": [],
"source": [
"# The convolutional block has a convolution layer, a max pool layer and an optional batch normalization layer\n",
"# (`stride` is the stride of the max pool layer; the convolution keeps its default stride of 1)\n",
"def conv_block(filters, kernel_size=2, stride=2, batch_norm=True):\n",
"    block = nn.HybridSequential()\n",
"    block.add(nn.Conv2D(channels=filters, kernel_size=kernel_size, activation='relu'),\n",
"              nn.MaxPool2D(pool_size=4, strides=stride))\n",
"    if batch_norm:\n",
"        block.add(nn.BatchNorm())\n",
"    return block\n",
"\n",
"# The dense block consists of a dense layer and an optional dropout layer\n",
"def dense_block(neurons, activation='relu', dropout=0.2):\n",
"    block = nn.HybridSequential()\n",
"    block.add(nn.Dense(neurons, activation=activation))\n",
"    if dropout:\n",
"        block.add(nn.Dropout(dropout))\n",
"    return block"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15e92664",
"metadata": {},
"outputs": [],
"source": [
"# Create neural network blueprint using the blocks\n",
"class LeafNetwork(nn.HybridBlock):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv1 = conv_block(32)\n",
"        self.conv2 = conv_block(64)\n",
"        self.conv3 = conv_block(128)\n",
"        self.flatten = nn.Flatten()\n",
"        self.dense1 = dense_block(100)\n",
"        self.dense2 = dense_block(10)\n",
"        self.dense3 = nn.Dense(2)  # one output per class\n",
"\n",
"    def forward(self, batch):\n",
"        batch = self.conv1(batch)\n",
"        batch = self.conv2(batch)\n",
"        batch = self.conv3(batch)\n",
"        batch = self.flatten(batch)\n",
"        batch = self.dense1(batch)\n",
"        batch = self.dense2(batch)\n",
"        batch = self.dense3(batch)\n",
"\n",
"        return batch"
]
},
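{
"cell_type": "markdown",
"id": "e8a07c44",
"metadata": {},
"source": [
"To see how a `(3, 128, 128)` input shrinks as it moves through `LeafNetwork`, you can apply the usual output-size formula for convolution and pooling, `floor((n + 2 * padding - kernel) / stride) + 1`, assuming the Gluon defaults of zero padding and stride 1 for `Conv2D`, and zero padding with floor rounding for `MaxPool2D`. The helpers `out_size` and `conv_block_size` below are hypothetical and only reproduce that arithmetic; `BatchNorm` does not change the shape.\n",
"\n",
"```python\n",
"def out_size(n, kernel, stride=1, padding=0):\n",
"    # Output length along one spatial dimension (floor rounding)\n",
"    return (n + 2 * padding - kernel) // stride + 1\n",
"\n",
"def conv_block_size(n, kernel_size=2, pool_size=4, stride=2):\n",
"    # Hypothetical mirror of conv_block: Conv2D (stride 1), then MaxPool2D(pool_size, stride)\n",
"    n = out_size(n, kernel_size)\n",
"    return out_size(n, pool_size, stride)\n",
"\n",
"size = 128\n",
"channels = [32, 64, 128]\n",
"for c in channels:\n",
"    size = conv_block_size(size)\n",
"    print(f'after conv block with {c} filters: ({c}, {size}, {size})')\n",
"\n",
"flat = channels[-1] * size * size\n",
"print('flattened features:', flat)  # 21632\n",
"print('dense1 parameters:', flat * 100 + 100)  # 2163300\n",
"```\n",
"\n",
"The spatial size goes 128 -> 62 -> 29 -> 13, so the first dense layer holds most of the parameters. You can compare these numbers with what `model.summary` prints once the model is initialized."
]
},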
{
"cell_type": "markdown",
"id": "edf5429f",
"metadata": {},
"source": [
"You have finished defining the architecture of the network, so now you can actually\n",
"build a model from that architecture for training. As you have seen\n",
"previously in [Step 4](4-components.md) of this\n",
"crash course series, to use the network you need to initialize the parameters and\n",
"hybridize the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d374c0cb",
"metadata": {},
"outputs": [],
"source": [
"# Create the model based on the blueprint provided and initialize the parameters\n",
"ctx = mx.cpu()\n",
"\n",
"initializer = mx.initializer.Xavier()\n",
"\n",
"model = LeafNetwork()\n",
"model.initialize(initializer, ctx=ctx)\n",
"model.summary(mx.np.random.uniform(size=(4, 3, 128, 128)))\n",
"model.hybridize()"
]
},
{
"cell_type": "markdown",
"id": "648ce7fd",
"metadata": {},
"source": [
"## 3. Choose Optimizer and Loss function\n",
"\n",
"With the network created, you can move on to choosing an optimizer and a loss\n",
"function. During training, the loss function measures how far the network's predictions are from the\n",
"true labels, and the optimizer uses the gradients of that loss to decide how to update the parameters. You can use the `gluon.Trainer` class to\n",
"help with optimizing these parameters. The `gluon.Trainer` class needs two things to work\n",
"properly: the parameters to be tuned and the optimizer with its\n",
"corresponding hyperparameters. The trainer uses the gradients of the loss\n",
"function to update these parameters.\n",
"\n",
"For this particular dataset you will use Stochastic Gradient Descent (SGD) as the\n",
"optimizer and softmax cross-entropy as the loss function."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "546e23f8",
"metadata": {},
"outputs": [],
"source": [
"# SGD optimizer\n",
"optimizer = 'sgd'\n",
"\n",
"# Set parameters\n",
"optimizer_params = {'learning_rate': 0.001}\n",
"\n",
"# Define the trainer for the model\n",
"trainer = gluon.Trainer(model.collect_params(), optimizer, optimizer_params)\n",
"\n",
"# Define the loss function\n",
"loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()"
]
},
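{
"cell_type": "markdown",
"id": "3f9b6d15",
"metadata": {},
"source": [
"For raw scores (logits) `z` and a true class `y`, softmax cross-entropy is `-log(softmax(z)[y])` per sample, and plain SGD updates each parameter as `w = w - learning_rate * grad`. The NumPy sketch below illustrates both on a tiny, made-up linear model; it is not Gluon's implementation. Since `trainer.step(batch_size)` normalizes the gradient by `1/batch_size`, the sketch averages the per-sample gradients in the same way.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def softmax(z):\n",
"    z = z - z.max(axis=1, keepdims=True)  # subtract the max for numerical stability\n",
"    e = np.exp(z)\n",
"    return e / e.sum(axis=1, keepdims=True)\n",
"\n",
"def softmax_cross_entropy(logits, labels):\n",
"    # Illustrative per-sample loss: -log of the probability assigned to the true class\n",
"    probs = softmax(logits)\n",
"    return -np.log(probs[np.arange(len(labels)), labels])\n",
"\n",
"# Made-up data: 4 samples with 3 features and binary labels\n",
"rng = np.random.default_rng(0)\n",
"X = rng.normal(size=(4, 3))\n",
"y = np.array([0, 1, 1, 0])\n",
"W = np.zeros((3, 2))\n",
"learning_rate = 0.1\n",
"\n",
"loss_before = softmax_cross_entropy(X @ W, y).mean()\n",
"\n",
"# Gradient of the loss w.r.t. the logits is softmax(z) - one_hot(y)\n",
"grad_logits = softmax(X @ W)\n",
"grad_logits[np.arange(len(y)), y] -= 1\n",
"grad_W = X.T @ grad_logits / len(y)  # average over the batch, like trainer.step(batch_size)\n",
"W = W - learning_rate * grad_W  # one SGD step\n",
"\n",
"loss_after = softmax_cross_entropy(X @ W, y).mean()\n",
"print(loss_before)  # log(2), about 0.693, because all logits start at zero\n",
"print(loss_after < loss_before)  # True\n",
"```"
]
},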
{
"cell_type": "markdown",
"id": "5100d019",
"metadata": {},
"source": [
"Finally, before writing the training loop, create a function that evaluates the performance of the network on the validation dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd9f6155",
"metadata": {},
"outputs": [],
"source": [
"# Function to return the accuracy for the validation and test sets\n",
"def test(val_data):\n",
"    acc = gluon.metric.Accuracy()\n",
"    for batch in val_data:\n",
"        data = batch[0]\n",
"        labels = batch[1]\n",
"        outputs = model(data)\n",
"        acc.update([labels], [outputs])\n",
"\n",
"    _, accuracy = acc.get()\n",
"    return accuracy"
]
},
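{
"cell_type": "markdown",
"id": "61c2ae7b",
"metadata": {},
"source": [
"For a classifier like this one, `gluon.metric.Accuracy` compares the predicted class (the index of the largest output) with the true label. The NumPy sketch below shows that computation on made-up outputs; the `accuracy` function is illustrative, not the metric's actual code.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def accuracy(outputs, labels):\n",
"    # Predicted class = index of the largest score in each row\n",
"    predictions = np.argmax(outputs, axis=1)\n",
"    return float(np.mean(predictions == labels))\n",
"\n",
"# Made-up network outputs for 4 images and their true labels\n",
"outputs = np.array([[2.0, -1.0],\n",
"                    [0.1, 0.3],\n",
"                    [1.5, 0.5],\n",
"                    [-0.2, 0.8]])\n",
"labels = np.array([0, 1, 1, 1])\n",
"print(accuracy(outputs, labels))  # 0.75\n",
"```\n",
"\n",
"Because `Accuracy` accumulates counts across calls to `update`, the `test` function can loop over all validation batches and read a single overall value at the end."
]
},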
{
"cell_type": "markdown",
"id": "44be0b78",
"metadata": {},
"source": [
"## 4. Training Loop\n",
"\n",
"Now that you have everything set up, you can start training your network. Training might\n",
"take some time depending on the hardware, number of layers, batch size and\n",
"number of images you use. For this particular case, you will only train for 2 epochs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79973b1b",
"metadata": {},
"outputs": [],
"source": [
"# Start the training loop\n",
"epochs = 2\n",
"accuracy = gluon.metric.Accuracy()\n",
"log_interval = 5\n",
"\n",
"for epoch in range(epochs):\n",
"    tic = time.time()\n",
"    btic = time.time()\n",
"    accuracy.reset()\n",
"\n",
"    for idx, batch in enumerate(train_loader):\n",
"        data = batch[0]\n",
"        label = batch[1]\n",
"        # Record the forward pass so autograd can compute gradients\n",
"        with mx.autograd.record():\n",
"            outputs = model(data)\n",
"            loss = loss_fn(outputs, label)\n",
"        loss.backward()\n",
"        # Normalize by the actual batch size (the last batch may be smaller)\n",
"        trainer.step(data.shape[0])\n",
"        accuracy.update([label], [outputs])\n",
"        if log_interval and (idx + 1) % log_interval == 0:\n",
"            _, acc = accuracy.get()\n",
"            # btic is reset every log_interval batches, so count all samples seen since then\n",
"            speed = log_interval * batch_size / (time.time() - btic)\n",
"            print(f\"Epoch[{epoch + 1}] Batch[{idx + 1}] Speed: {speed:.2f} samples/sec | \"\n",
"                  f\"batch loss = {loss.mean().item():.4f} | accuracy = {acc:.4f}\")\n",
"            btic = time.time()\n",
"\n",
"    _, acc = accuracy.get()\n",
"\n",
"    acc_val = test(validation_loader)\n",
"    print(f\"[Epoch {epoch + 1}] training: accuracy={acc}\")\n",
"    print(f\"[Epoch {epoch + 1}] time cost: {time.time() - tic}\")\n",
"    print(f\"[Epoch {epoch + 1}] validation: validation accuracy={acc_val}\")"
]
},
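{
"cell_type": "markdown",
"id": "94d8f0e3",
"metadata": {},
"source": [
"Stripped of MXNet, each epoch of the loop above follows the same pattern: for every batch, run the forward pass, compute the loss and its gradients, update the parameters, and accumulate accuracy. The NumPy sketch below shows that pattern for a softmax classifier on made-up, linearly separable data, with gradients written out by hand where Gluon would use `autograd`. All names and hyperparameters here are illustrative assumptions.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"# Made-up data: the label is 1 when the first feature is positive\n",
"X = rng.normal(size=(200, 2))\n",
"y = (X[:, 0] > 0).astype(int)\n",
"\n",
"W = np.zeros((2, 2))\n",
"b = np.zeros(2)\n",
"learning_rate, batch_size, epochs = 0.5, 20, 5\n",
"\n",
"for epoch in range(epochs):\n",
"    order = rng.permutation(len(X))  # reshuffle every epoch, like shuffle=True\n",
"    correct = 0\n",
"    for start in range(0, len(X), batch_size):\n",
"        idx = order[start:start + batch_size]\n",
"        xb, yb = X[idx], y[idx]\n",
"        logits = xb @ W + b  # forward pass\n",
"        z = logits - logits.max(axis=1, keepdims=True)  # stabilize the softmax\n",
"        probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)\n",
"        correct += int((probs.argmax(axis=1) == yb).sum())\n",
"        grad = probs.copy()  # gradient of the cross-entropy w.r.t. the logits\n",
"        grad[np.arange(len(yb)), yb] -= 1\n",
"        W -= learning_rate * (xb.T @ grad) / len(yb)  # SGD step averaged over the batch\n",
"        b -= learning_rate * grad.sum(axis=0) / len(yb)\n",
"    print(f'[Epoch {epoch + 1}] training accuracy={correct / len(X):.3f}')\n",
"\n",
"final_acc = float(((X @ W + b).argmax(axis=1) == y).mean())\n",
"print('final accuracy:', final_acc)\n",
"```"
]
},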
{
"cell_type": "markdown",
"id": "473cd2e0",
"metadata": {},
"source": [
"## 5. Test on the test set\n",
"\n",
"Now that your network is trained, you can\n",
"evaluate its performance on the test set. For that, you can use the `test_loader` data\n",
"loader and the `test` function you created previously."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25f45637",
"metadata": {},
"outputs": [],
"source": [
"test(test_loader)"
]
},
{
"cell_type": "markdown",
"id": "cbed9eaf",
"metadata": {},
"source": [
"You now have a trained network that can discriminate between plants that\n",
"are healthy and ones that are diseased (check the test accuracy above to see how reliably it does so after only 2 epochs).\n",
"You can now start your garden and\n",
"set up cameras to automatically detect plants in distress! Or you can change your classification\n",
"problem to create a model that classifies the species of the plants! Either way, you\n",
"might be able to impress your botanist friends.\n",
"\n",
"## 6. Save the parameters\n",
"\n",
"If you want to preserve the trained weights of the network, you can save the\n",
"parameters to a file. Later, when you want to use the network to make predictions,\n",
"you can load the parameters back!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b9659cc",
"metadata": {},
"outputs": [],
"source": [
"# Save the trained parameters to a file\n",
"model.save_parameters('leaf_models.params')"
]
},
{
"cell_type": "markdown",
"id": "1392cce3",
"metadata": {},
"source": [
"This is the end of this tutorial. To see how you can speed up the training by\n",
"using GPU hardware, continue to the [next tutorial](./7-use-gpus.ipynb)."
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}