{
"cells": [
{
"cell_type": "markdown",
"id": "34a0fcf7",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"\n",
"# Gluon: from experiment to deployment\n",
"\n",
"## Overview\n",
"MXNet Gluon API comes with a lot of great features, and it can provide you everything you need: from experimentation to deploying the model. In this tutorial, we will walk you through a common use case on how to build a model using gluon, train it on your data, and deploy it for inference.\n",
"This tutorial covers training and inference in Python, please continue to [C++ inference part](/api/cpp/docs/tutorials/cpp_inference) after you finish.\n",
"\n",
"Let's say you need to build a service that provides flower species recognition. A common problem is that you don't have enough data to train a good model. In such cases, a technique called Transfer Learning can be used to make a more robust model.\n",
"In Transfer Learning we make use of a pre-trained model that solves a related task, and was trained on a very large standard dataset, such as ImageNet. ImageNet is from a different domain, but we can utilize the knowledge in this pre-trained model to perform the new task at hand.\n",
"\n",
"Gluon provides State of the Art models for many of the standard tasks such as Classification, Object Detection, Segmentation, etc. In this tutorial we will use the pre-trained model [ResNet50 V2](https://arxiv.org/abs/1603.05027) trained on ImageNet dataset. This model achieves 77.11% top-1 accuracy on ImageNet. We seek to transfer as much knowledge as possible for our task of recognizing different species of flowers.\n",
"\n",
"\n",
"\n",
"\n",
"## Prerequisites\n",
"\n",
"To complete this tutorial, you need:\n",
"\n",
"- [Build MXNet from source](https://mxnet.apache.org/get_started/ubuntu_setup#build-mxnet-from-source) with Python(Gluon) and C++ Packages\n",
"- Learn the basics about Gluon with [A 60-minute Gluon Crash Course](https://gluon-crash-course.mxnet.io/)\n",
"\n",
"\n",
"## The Data\n",
"\n",
"We will use the [Oxford 102 Category Flower Dataset](http://www.robots.ox.ac.uk/~vgg/data/flowers/102/) as an example to show you the steps.\n",
"We have prepared a utility file to help you download and organize your data into train, test, and validation sets. Run the following Python code to download and prepare the data:"
]
},
{
"cell_type": "markdown",
"id": "c969af62",
"metadata": {},
"source": [
"```python\n",
"import mxnet as mx\n",
"data_util_file = \"oxford_102_flower_dataset.py\"\n",
"base_url = \"https://raw.githubusercontent.com/apache/mxnet/master/docs/tutorial_utils/data/{}?raw=true\"\n",
"mx.test_utils.download(base_url.format(data_util_file), fname=data_util_file)\n",
"import oxford_102_flower_dataset\n",
"\n",
"# download and move data to train, test, valid folders\n",
"path = './data'\n",
"oxford_102_flower_dataset.get_data(path)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "12d043f5",
"metadata": {},
"source": [
"Now your data will be organized into train, test, and validation sets, images belong to the same class are moved to the same folder.\n",
"\n",
"## Training using Gluon\n",
"\n",
"### Define Hyper-parameters\n",
"\n",
"Now let's first import necessary packages:"
]
},
{
"cell_type": "markdown",
"id": "8aecd085",
"metadata": {},
"source": [
"```python\n",
"import math\n",
"import os\n",
"import time\n",
"\n",
"from mxnet import autograd\n",
"from mxnet import gluon, init\n",
"from mxnet.gluon import nn\n",
"from mxnet.gluon.data.vision import transforms\n",
"from mxnet.gluon.model_zoo.vision import resnet50_v2\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "1d4a5f6c",
"metadata": {},
"source": [
"Next, we define the hyper-parameters that we will use for fine-tuning. We will use the [MXNet learning rate scheduler](/api/python/docs/tutorials/packages/gluon/training/learning_rates/learning_rate_schedules.html) to adjust learning rates during training.\n",
"Here we set the `epochs` to 1 for quick demonstration, please change to 40 for actual training."
]
},
{
"cell_type": "markdown",
"id": "f1a555fe",
"metadata": {},
"source": [
"```python\n",
"classes = 102\n",
"epochs = 1\n",
"lr = 0.001\n",
"per_device_batch_size = 32\n",
"momentum = 0.9\n",
"wd = 0.0001\n",
"\n",
"lr_factor = 0.75\n",
"# learning rate change at following epochs\n",
"lr_epochs = [10, 20, 30]\n",
"\n",
"num_gpus = mx.context.num_gpus()\n",
"# you can replace num_workers with the number of cores on you device\n",
"num_workers = 8\n",
"ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]\n",
"batch_size = per_device_batch_size * max(num_gpus, 1)\n",
"```\n"
]
},
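{
"cell_type": "markdown",
"id": "5b2e9a10",
"metadata": {},
"source": [
"As a rough illustration of what these settings imply, the following plain-Python sketch computes the stepped learning rate for a given epoch. It does not use MXNet; the helper name `lr_at_epoch` is hypothetical, and it assumes the rate is multiplied by `lr_factor` once for every boundary in `lr_epochs` that the (1-indexed) epoch has already passed.\n",
"\n",
"```python\n",
"# Values copied from the hyper-parameters above.\n",
"base_lr = 0.001\n",
"lr_factor = 0.75\n",
"lr_epochs = [10, 20, 30]\n",
"\n",
"\n",
"def lr_at_epoch(epoch):\n",
"    # Count the boundaries already passed by this (1-indexed) epoch.\n",
"    passed = sum(1 for boundary in lr_epochs if epoch > boundary)\n",
"    return base_lr * lr_factor ** passed\n",
"\n",
"\n",
"for epoch in (1, 10, 11, 21, 40):\n",
"    print('epoch %2d -> lr %.6f' % (epoch, lr_at_epoch(epoch)))\n",
"```\n",
"\n",
"Under these assumptions the rate stays at 0.001 for the first 10 epochs and has been reduced three times by epoch 40, which gives roughly 4.2e-4."
]
},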
{
"cell_type": "markdown",
"id": "480e7314",
"metadata": {},
"source": [
"Now we will apply data augmentations on training images. This makes minor alterations on the training images, and our model will consider them as distinct images. This can be very useful for fine-tuning on a relatively small dataset, and it will help improve the model. We can use the Gluon [DataSet API](/api/python/docs/api/gluon/data/index.html#mxnet.gluon.data.Dataset), [DataLoader API](/api/python/docs/api/gluon/data/index.html#mxnet.gluon.data.DataLoader), and [Transform API](/api/python/docs/api/gluon/data/index.html#mxnet.gluon.data.Dataset.transform) to load the images and apply the following data augmentations:\n",
"1. Randomly crop the image and resize it to 224x224\n",
"2. Randomly flip the image horizontally\n",
"3. Randomly jitter color and add noise\n",
"4. Transpose the data from `[height, width, num_channels]` to `[num_channels, height, width]`, and map values from [0, 255] to [0, 1]\n",
"5. Normalize with the mean and standard deviation from the ImageNet dataset.\n",
"\n",
"For validation and inference, we only need to apply step 1, 4, and 5. We also need to save the mean and standard deviation values for [inference using C++](/api/cpp/docs/tutorials/cpp_inference)."
]
},
{
"cell_type": "markdown",
"id": "eef89eee",
"metadata": {},
"source": [
"```python\n",
"jitter_param = 0.4\n",
"lighting_param = 0.1\n",
"\n",
"# mean and std for normalizing image value in range (0,1)\n",
"mean = [0.485, 0.456, 0.406]\n",
"std = [0.229, 0.224, 0.225]\n",
"\n",
"training_transformer = transforms.Compose([\n",
" transforms.RandomResizedCrop(224),\n",
" transforms.RandomFlipLeftRight(),\n",
" transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,\n",
" saturation=jitter_param),\n",
" transforms.RandomLighting(lighting_param),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean, std)\n",
"])\n",
"\n",
"validation_transformer = transforms.Compose([\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean, std)\n",
"])\n",
"\n",
"# save mean and std NDArray values for inference\n",
"mean_img = mx.nd.stack(*[mx.nd.full((224, 224), m) for m in mean])\n",
"std_img = mx.nd.stack(*[mx.nd.full((224, 224), s) for s in std])\n",
"mx.nd.save('mean_std_224.nd', {\"mean_img\": mean_img, \"std_img\": std_img})\n",
"\n",
"train_path = os.path.join(path, 'train')\n",
"val_path = os.path.join(path, 'valid')\n",
"test_path = os.path.join(path, 'test')\n",
"\n",
"# loading the data and apply pre-processing(transforms) on images\n",
"train_data = gluon.data.DataLoader(\n",
" gluon.data.vision.ImageFolderDataset(train_path).transform_first(training_transformer),\n",
" batch_size=batch_size, shuffle=True, num_workers=num_workers)\n",
"\n",
"val_data = gluon.data.DataLoader(\n",
" gluon.data.vision.ImageFolderDataset(val_path).transform_first(validation_transformer),\n",
" batch_size=batch_size, shuffle=False, num_workers=num_workers)\n",
"\n",
"test_data = gluon.data.DataLoader(\n",
" gluon.data.vision.ImageFolderDataset(test_path).transform_first(validation_transformer),\n",
" batch_size=batch_size, shuffle=False, num_workers=num_workers)\n",
"```\n"
]
},
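{
"cell_type": "markdown",
"id": "c4d81f3e",
"metadata": {},
"source": [
"The `ToTensor` and `Normalize` steps above reduce to simple per-channel arithmetic. The sketch below reproduces that arithmetic for a single pixel in plain Python. `normalize_pixel` is a hypothetical helper and the pixel value is made up, so treat it as an illustration of the idea rather than MXNet's implementation.\n",
"\n",
"```python\n",
"# ImageNet channel statistics, as used in the transforms above.\n",
"mean = [0.485, 0.456, 0.406]\n",
"std = [0.229, 0.224, 0.225]\n",
"\n",
"\n",
"def normalize_pixel(rgb):\n",
"    # Map each 0-255 channel value to [0, 1], then standardize it.\n",
"    scaled = [value / 255.0 for value in rgb]\n",
"    return [(s - m) / d for s, m, d in zip(scaled, mean, std)]\n",
"\n",
"\n",
"# Made-up example pixel (R, G, B).\n",
"print(normalize_pixel([255, 128, 0]))\n",
"```\n",
"\n",
"Any consumer of the exported model, such as the C++ inference code, has to apply the same scaling and the same mean and standard deviation, which is why those values are saved to `mean_std_224.nd`."
]
},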
{
"cell_type": "markdown",
"id": "fdbb5f43",
"metadata": {},
"source": [
"### Loading pre-trained model\n",
"\n",
"\n",
"We will use pre-trained ResNet50_v2 model which was pre-trained on the [ImageNet Dataset](http://www.image-net.org/) with 1000 classes. To match the classes in the Flower dataset, we must redefine the last softmax (output) layer to be 102, then initialize the parameters.\n",
"\n",
"Before we go to training, one unique Gluon feature you should be aware of is hybridization. It allows you to convert your imperative code to a static symbolic graph, which is much more efficient to execute. There are two main benefits of hybridizing your model: better performance and easier serialization for deployment. The best part is that it's as simple as just calling `net.hybridize()`. To know more about Gluon hybridization, please follow the [hybridization tutorial](/api/python/docs/tutorials/packages/gluon/blocks/hybridize.html)."
]
},
{
"cell_type": "markdown",
"id": "8cc7a4e2",
"metadata": {},
"source": [
"```python\n",
"# load pre-trained resnet50_v2 from model zoo\n",
"finetune_net = resnet50_v2(pretrained=True, ctx=ctx)\n",
"\n",
"# change last softmax layer since number of classes are different\n",
"with finetune_net.name_scope():\n",
" finetune_net.output = nn.Dense(classes)\n",
"finetune_net.output.initialize(init.Xavier(), ctx=ctx)\n",
"# hybridize for better performance\n",
"finetune_net.hybridize()\n",
"\n",
"num_batch = len(train_data)\n",
"\n",
"# setup learning rate scheduler\n",
"iterations_per_epoch = math.ceil(num_batch)\n",
"# learning rate change at following steps\n",
"lr_steps = [epoch * iterations_per_epoch for epoch in lr_epochs]\n",
"schedule = mx.lr_scheduler.MultiFactorScheduler(step=lr_steps, factor=lr_factor, base_lr=lr)\n",
"\n",
"# setup optimizer with learning rate scheduler, metric, and loss function\n",
"sgd_optimizer = mx.optimizer.SGD(learning_rate=lr, lr_scheduler=schedule, momentum=momentum, wd=wd)\n",
"metric = mx.metric.Accuracy()\n",
"softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "db45dd9d",
"metadata": {},
"source": [
"### Fine-tuning model on your custom dataset\n",
"\n",
"Now let's define the test metrics and start fine-tuning."
]
},
{
"cell_type": "markdown",
"id": "982645df",
"metadata": {},
"source": [
"```python\n",
"def test(net, val_data, ctx):\n",
" metric = mx.metric.Accuracy()\n",
" for i, (data, label) in enumerate(val_data):\n",
" data = gluon.utils.split_and_load(data, ctx_list=ctx, even_split=False)\n",
" label = gluon.utils.split_and_load(label, ctx_list=ctx, even_split=False)\n",
" outputs = [net(x) for x in data]\n",
" metric.update(label, outputs)\n",
" return metric.get()\n",
"\n",
"trainer = gluon.Trainer(finetune_net.collect_params(), optimizer=sgd_optimizer)\n",
"\n",
"# start with epoch 1 for easier learning rate calculation\n",
"for epoch in range(1, epochs + 1):\n",
"\n",
" tic = time.time()\n",
" train_loss = 0\n",
" metric.reset()\n",
"\n",
" for i, (data, label) in enumerate(train_data):\n",
" # get the images and labels\n",
" data = gluon.utils.split_and_load(data, ctx_list=ctx, even_split=False)\n",
" label = gluon.utils.split_and_load(label, ctx_list=ctx, even_split=False)\n",
" with autograd.record():\n",
" outputs = [finetune_net(x) for x in data]\n",
" loss = [softmax_cross_entropy(yhat, y) for yhat, y in zip(outputs, label)]\n",
" for l in loss:\n",
" l.backward()\n",
"\n",
" trainer.step(batch_size)\n",
" train_loss += sum([l.mean().asscalar() for l in loss]) / len(loss)\n",
" metric.update(label, outputs)\n",
"\n",
" _, train_acc = metric.get()\n",
" train_loss /= num_batch\n",
" _, val_acc = test(finetune_net, val_data, ctx)\n",
"\n",
" print('[Epoch %d] Train-acc: %.3f, loss: %.3f | Val-acc: %.3f | learning-rate: %.3E | time: %.1f' %\n",
" (epoch, train_acc, train_loss, val_acc, trainer.learning_rate, time.time() - tic))\n",
"\n",
"_, test_acc = test(finetune_net, test_data, ctx)\n",
"print('[Finished] Test-acc: %.3f' % (test_acc))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "7a27891c",
"metadata": {},
"source": [
"Following is the training result:"
]
},
{
"cell_type": "markdown",
"id": "65e6aceb",
"metadata": {},
"source": [
"```text\n",
"[Epoch 40] Train-acc: 0.945, loss: 0.354 | Val-acc: 0.955 | learning-rate: 4.219E-04 | time: 17.8\n",
"[Finished] Test-acc: 0.952\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "010fdf74",
"metadata": {},
"source": [
"In the previous example output, we trained the model using an [AWS p3.8xlarge instance](https://aws.amazon.com/ec2/instance-types/p3/) with 4 Tesla V100 GPUs. We were able to reach a test accuracy of 95.5% with 40 epochs in around 12 minutes. This was really fast because our model was pre-trained on a much larger dataset, ImageNet, with around 1.3 million images. It worked really well to capture features on our small dataset.\n",
"\n",
"\n",
"### Save the fine-tuned model\n",
"\n",
"\n",
"We now have a trained our custom model. This can be serialized into model files using the export function. The export function will export the model architecture into a `.json` file and model parameters into a `.params` file."
]
},
{
"cell_type": "markdown",
"id": "1b440f4e",
"metadata": {},
"source": [
"```python\n",
"finetune_net.export(\"flower-recognition\", epoch=epochs)\n",
"\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "3d1c13c1",
"metadata": {},
"source": [
"`export` creates `flower-recognition-symbol.json` and `flower-recognition-0040.params` (`0040` is for 40 epochs we ran) in the current directory. These files can be used for model deployment in the next section.\n",
"\n",
"## Load the model and run inference using the MXNet Module API\n",
"\n",
"MXNet provides various useful tools and interfaces for deploying your model for inference. For example, you can use [MXNet Model Server](https://github.com/awslabs/mxnet-model-server) to start a service and host your trained model easily.\n",
"Besides that, you can also use MXNet's different language APIs to integrate your model with your existing service. We provide [Python](/api/python.html), [Java](/api/java.html), [Scala](/api/scala.html), and [C++](/api/cpp) APIs.\n",
"\n",
"Here we will briefly introduce how to run inference using Module API in Python. In general, prediction consists of the following steps:\n",
"1. Load the model architecture (symbol file) and trained parameter values (params file)\n",
"2. Load the synset file for label names\n",
"3. Load the image and apply the same transformation we did on validation dataset during training\n",
"4. Run a forward pass on the image data\n",
"5. Convert output probabilities to predicted label name"
]
},
{
"cell_type": "markdown",
"id": "304952ca",
"metadata": {},
"source": [
"```python\n",
"import numpy as np\n",
"from collections import namedtuple\n",
"\n",
"ctx = mx.cpu()\n",
"# load model symbol and params\n",
"sym, arg_params, aux_params = mx.model.load_checkpoint('flower-recognition', epochs)\n",
"mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)\n",
"mod.bind(for_training=False, data_shapes=[('data', (1, 3, 224, 224))], label_shapes=mod._label_shapes)\n",
"mod.set_params(arg_params, aux_params, allow_missing=True)\n",
"\n",
"# load synset for label names\n",
"with open('synset.txt', 'r') as f:\n",
" labels = [l.rstrip() for l in f]\n",
"\n",
"# load an image for prediction\n",
"img = mx.image.imread('./data/test/lotus/image_01832.jpg')\n",
"# apply transform we did during training\n",
"img = validation_transformer(img)\n",
"# batchify\n",
"img = img.expand_dims(axis=0)\n",
"Batch = namedtuple('Batch', ['data'])\n",
"mod.forward(Batch([img]))\n",
"prob = mod.get_outputs()[0].asnumpy()\n",
"prob = np.squeeze(prob)\n",
"idx = np.argmax(prob)\n",
"print('probability=%f, class=%s' % (prob[idx], labels[idx]))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "7838af53",
"metadata": {},
"source": [
"Following is the output, you can see the image has been classified as lotus correctly."
]
},
{
"cell_type": "markdown",
"id": "4c2fae9e",
"metadata": {},
"source": [
"```text\n",
"probability=9.798435, class=lotus\n",
"```\n"
]
},
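{
"cell_type": "markdown",
"id": "9e3f7b26",
"metadata": {},
"source": [
"The value printed above is larger than 1, so it is presumably a raw output score rather than a normalized probability. If a probability is needed, one common option is to apply a softmax to the scores. The plain-Python sketch below shows the idea on made-up scores for three hypothetical classes; it is an illustration only, not the code used to produce the output above.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def softmax(scores):\n",
"    # Subtract the maximum score for numerical stability before exponentiating.\n",
"    top = max(scores)\n",
"    exps = [math.exp(s - top) for s in scores]\n",
"    total = sum(exps)\n",
"    return [e / total for e in exps]\n",
"\n",
"\n",
"# Made-up raw scores for three hypothetical classes.\n",
"scores = [9.8, 2.1, -0.5]\n",
"probs = softmax(scores)\n",
"best = max(range(len(probs)), key=probs.__getitem__)\n",
"print('class index=%d, probability=%.4f' % (best, probs[best]))\n",
"```\n",
"\n",
"Softmax does not change which class has the highest score, so the predicted label stays the same. In the MXNet code above, a similar result could likely be obtained by passing the module output through `mx.nd.softmax` before calling `asnumpy()`."
]
},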
{
"cell_type": "markdown",
"id": "bc59440b",
"metadata": {},
"source": [
"## What's next\n",
"\n",
"You can continue to the [next tutorial](/api/cpp/docs/tutorials/cpp_inference) on how to load the model we just trained and run inference using MXNet C++ API.\n",
"\n",
"You can also find more ways to run inference and deploy your models here:\n",
"1. [Java Inference examples](https://github.com/apache/mxnet/tree/master/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer)\n",
"2. [Scala Inference examples](/api/scala/docs/tutorials/infer)\n",
"4. [MXNet Model Server Examples](https://github.com/awslabs/mxnet-model-server/tree/master/examples)\n",
"\n",
"## References\n",
"\n",
"1. [Transfer Learning for Oxford102 Flower Dataset](https://github.com/Arsey/keras-transfer-learning-for-oxford102)\n",
"2. [Gluon book on fine-tuning](https://www.d2l.ai/chapter_computer-vision/fine-tuning.html)\n",
"3. [Gluon CV transfer learning tutorial](https://gluon-cv.mxnet.io/build/examples_classification/transfer_learning_minc.html)\n",
"4. [Gluon crash course](https://gluon-crash-course.mxnet.io/)\n",
"5. [Gluon CPP inference example](https://github.com/apache/mxnet/blob/master/cpp-package/example/inference/)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}