{
"cells": [
{
"cell_type": "markdown",
"id": "1c491a43",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"# Predict with a pre-trained model\n",
"\n",
"A saved model can be used in several ways: to continue training, to fine-tune the model, or to make predictions. In this tutorial we discuss how to predict new examples using a pre-trained model.\n",
"\n",
"## Prerequisites\n",
"\n",
"Please run the [previous tutorial](4-train.html) to train the network and save its parameters to a file. You will need this file to run the following steps."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "bc27e3ee",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "1"
}
},
"outputs": [],
"source": [
"from mxnet import nd\n",
"from mxnet import gluon\n",
"from mxnet.gluon import nn\n",
"from mxnet.gluon.data.vision import datasets, transforms\n",
"from IPython import display\n",
"import matplotlib.pyplot as plt"
]
},
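{
"cell_type": "markdown",
"id": "090d114c",
"metadata": {},
"source": [
"To start, we copy the definition of the simple model that was trained in the previous tutorial. The architecture must match the one used when the parameters were saved."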
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "51a2071d",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "2"
}
},
"outputs": [],
"source": [
"net = nn.Sequential()\n",
"net.add(nn.Conv2D(channels=6, kernel_size=5, activation='relu'),\n",
"        nn.MaxPool2D(pool_size=2, strides=2),\n",
"        nn.Conv2D(channels=16, kernel_size=3, activation='relu'),\n",
"        nn.MaxPool2D(pool_size=2, strides=2),\n",
"        nn.Flatten(),\n",
"        nn.Dense(120, activation='relu'),\n",
"        nn.Dense(84, activation='relu'),\n",
"        nn.Dense(10))"
]
},
{
"cell_type": "markdown",
"id": "0eb394ab",
"metadata": {},
"source": [
"In the previous tutorial, we saved all parameters to a file. Now let's load them back."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c3f0db04",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "3"
}
},
"outputs": [],
"source": [
"net.load_parameters('net.params')"
]
},
{
"cell_type": "markdown",
"id": "bf83b73f",
"metadata": {},
"source": [
"## Predict\n",
"\n",
"Remember the data transformation we did for training? Now we need the same transformation for predicting."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9796e13a",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "4"
}
},
"outputs": [],
"source": [
"transformer = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(0.13, 0.31)])"
]
},
{
"cell_type": "markdown",
"id": "af3879d9",
"metadata": {},
"source": [
"Now let's predict the first ten images in the validation dataset and store the predictions in `preds`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ad1313a5",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "5"
}
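},
"outputs": [],
"source": [
"mnist_valid = datasets.FashionMNIST(train=False)\n",
"X, y = mnist_valid[:10]\n",
"preds = []\n",
"for x in X:\n",
"    # Add a batch axis so the input has shape (1, channel, height, width).\n",
"    x = transformer(x).expand_dims(axis=0)\n",
"    # The predicted class is the index of the largest output score.\n",
"    pred = net(x).argmax(axis=1)\n",
"    preds.append(pred.astype('int32').asscalar())"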
]
},
{
"cell_type": "markdown",
"id": "1c676690",
"metadata": {},
"source": [
"Finally, we visualize the images and compare the predictions with the ground truth. Each title shows the true label on the first line and the predicted label on the second."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "37e413f1",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "15"
}
},
"outputs": [],
"source": [
"_, figs = plt.subplots(1, 10, figsize=(15, 15))\n",
"text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',\n",
"               'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']\n",
"display.set_matplotlib_formats('svg')\n",
"for f, x, yi, pyi in zip(figs, X, y, preds):\n",
"    f.imshow(x.reshape((28,28)).asnumpy())\n",
"    ax = f.axes\n",
"    ax.set_title(text_labels[yi]+'\\n'+text_labels[pyi])\n",
"    ax.title.set_fontsize(14)\n",
"    ax.get_xaxis().set_visible(False)\n",
"    ax.get_yaxis().set_visible(False)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b8f46119",
"metadata": {},
"source": [
"## Predict with models from Gluon model zoo\n",
"\n",
"\n",
"The LeNet trained on FashionMNIST is a good example to start with, but it is too simple to recognize objects in real-life pictures. Instead of training a large-scale model from scratch, we can use the [Gluon model zoo](https://mxnet.apache.org/api/python/gluon/model_zoo.html), which provides several powerful pre-trained models. For example, we can download and load a pre-trained ResNet-50 V2 model that was trained on the ImageNet dataset."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9f6b6054",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "7"
}
},
"outputs": [],
"source": [
"from mxnet.gluon.model_zoo import vision as models\n",
"from mxnet.gluon.utils import download\n",
"from mxnet import image\n",
"\n",
"net = models.resnet50_v2(pretrained=True)"
]
},
{
"cell_type": "markdown",
"id": "9b523ec4",
"metadata": {},
"source": [
"We also download and load the text labels for each class."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9bcef5db",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "8"
}
},
"outputs": [],
"source": [
"url = 'http://data.mxnet.io/models/imagenet/synset.txt'\n",
"fname = download(url)\n",
"with open(fname, 'r') as f:\n",
"    # Each line is a class ID followed by its label; keep only the label.\n",
"    text_labels = [' '.join(line.split()[1:]) for line in f]"
]
},
{
"cell_type": "markdown",
"id": "a13feee8",
"metadata": {},
"source": [
"We pick an arbitrary dog image from Wikipedia as a test image, then download it and read it."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "acd48643",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "9"
}
},
"outputs": [],
"source": [
"url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b5/\\\n",
"Golden_Retriever_medium-to-light-coat.jpg/\\\n",
"365px-Golden_Retriever_medium-to-light-coat.jpg'\n",
"fname = download(url)\n",
"x = image.imread(fname)"
]
},
{
"cell_type": "markdown",
"id": "86090e7f",
"metadata": {},
"source": [
"We follow the conventional way of preprocessing ImageNet data:\n",
"\n",
"1. Resize the image so that its short edge is 256 pixels.\n",
"2. Perform a center crop to obtain a 224-by-224 image."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b3186bf4",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "10"
}
},
"outputs": [],
"source": [
"x = image.resize_short(x, 256)\n",
"x, _ = image.center_crop(x, (224,224))\n",
"plt.imshow(x.asnumpy())\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "7db22b99",
"metadata": {},
"source": [
"You can probably tell that it is a golden retriever (you can also infer this from the image URL).\n",
"\n",
"The further data transformation is similar to the one for FashionMNIST, except that we subtract the RGB means and divide by the corresponding standard deviations to normalize each color channel."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "4bb69e21",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "11"
}
},
"outputs": [],
"source": [
"def transform(data):\n",
"    # Convert (height, width, channel) to (batch, channel, height, width).\n",
"    data = data.transpose((2,0,1)).expand_dims(axis=0)\n",
"    rgb_mean = nd.array([0.485, 0.456, 0.406]).reshape((1,3,1,1))\n",
"    rgb_std = nd.array([0.229, 0.224, 0.225]).reshape((1,3,1,1))\n",
"    # Scale pixel values to [0, 1], then normalize each color channel.\n",
"    return (data.astype('float32') / 255 - rgb_mean) / rgb_std"
]
},
{
"cell_type": "markdown",
"id": "ede29f1d",
"metadata": {},
"source": [
"Now we can recognize the object in the image. We apply a softmax to the network output to obtain probability scores and then print the top-5 recognized objects."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2e43d327",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "12"
}
},
"outputs": [],
"source": [
"prob = net(transform(x)).softmax()\n",
"# topk returns the indices of the five largest probabilities.\n",
"idx = prob.topk(k=5)[0]\n",
"for i in idx:\n",
"    i = int(i.asscalar())\n",
"    print('With prob = %.5f, it contains %s' % (\n",
"        prob[0,i].asscalar(), text_labels[i]))"
]
},
{
"cell_type": "markdown",
"id": "06602be5",
"metadata": {},
"source": [
"As can be seen, the model is fairly confident the image contains a golden retriever."
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}