| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "id": "a9cb12df", |
| "metadata": {}, |
| "source": [ |
| "<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n", |
| "<!--- or more contributor license agreements. See the NOTICE file -->\n", |
| "<!--- distributed with this work for additional information -->\n", |
| "<!--- regarding copyright ownership. The ASF licenses this file -->\n", |
| "<!--- to you under the Apache License, Version 2.0 (the -->\n", |
| "<!--- \"License\"); you may not use this file except in compliance -->\n", |
| "<!--- with the License. You may obtain a copy of the License at -->\n", |
| "\n", |
| "<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n", |
| "\n", |
| "<!--- Unless required by applicable law or agreed to in writing, -->\n", |
| "<!--- software distributed under the License is distributed on an -->\n", |
| "<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n", |
| "<!--- KIND, either express or implied. See the License for the -->\n", |
| "<!--- specific language governing permissions and limitations -->\n", |
| "<!--- under the License. -->\n", |
| "\n", |
| "\n", |
| "# Fine-tuning an ONNX model\n", |
| "\n", |
| "Fine-tuning is a common practice in Transfer Learning. One can take advantage of the pre-trained weights of a network, and use them as an initializer for their own task. Indeed, quite often it is difficult to gather a dataset large enough that it would allow training from scratch deep and complex networks such as ResNet152 or VGG16. For example in an image classification task, using a network trained on a large dataset like ImageNet gives a good base from which the weights can be slightly updated, or fine-tuned, to predict accurately the new classes. We will see in this tutorial that this can be achieved even with a relatively small number of new training examples.\n", |
| "\n", |
| "\n", |
| "[Open Neural Network Exchange (ONNX)](https://github.com/onnx/onnx) provides an open source format for AI models. It defines an extensible computation graph model, as well as definitions of built-in operators and standard data types.\n", |
| "\n", |
| "In this tutorial we will:\n", |
| "\n", |
| "- learn how to pick a specific layer from a pre-trained .onnx model file\n", |
| "- learn how to load this model in Gluon and fine-tune it on a different dataset\n", |
| "\n", |
| "## Pre-requisite\n", |
| "\n", |
| "To run the tutorial you will need to have installed the following python modules:\n", |
| "- [MXNet > 1.1.0](https://mxnet.apache.org/get_started)\n", |
| "- [onnx](https://github.com/onnx/onnx)\n", |
| "- matplotlib\n", |
| "\n", |
| "We recommend that you have first followed this tutorial:\n", |
| "- [Inference using an ONNX model on MXNet Gluon](./inference_on_onnx_model.ipynb)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "85c03a54", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import json\n", |
| "import logging\n", |
| "import multiprocessing\n", |
| "import os\n", |
| "import tarfile\n", |
| "\n", |
| "logging.basicConfig(level=logging.INFO)\n", |
| "\n", |
| "import matplotlib.pyplot as plt\n", |
| "import mxnet as mx\n", |
| "from mxnet import gluon, np, npx, autograd\n", |
| "from mxnet.gluon.data.vision.datasets import ImageFolderDataset\n", |
| "from mxnet.gluon.data import DataLoader\n", |
| "import mxnet.contrib.onnx as onnx_mxnet\n", |
| "import numpy as onp\n", |
| "\n", |
| "%matplotlib inline" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "443b027f", |
| "metadata": {}, |
| "source": [ |
| "### Downloading supporting files\n", |
| "These are images and a vizualisation script:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "278a1bb8", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "image_folder = \"images\"\n", |
| "utils_file = \"utils.py\" # contain utils function to plot nice visualization\n", |
| "images = ['wrench.jpg', 'dolphin.jpg', 'lotus.jpg']\n", |
| "base_url = \"https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/{}?raw=true\"\n", |
| "\n", |
| "\n", |
| "for image in images:\n", |
| " mx.test_utils.download(base_url.format(\"{}/{}\".format(image_folder, image)), fname=image,dirname=image_folder)\n", |
| "mx.test_utils.download(base_url.format(utils_file), fname=utils_file)\n", |
| "\n", |
| "from utils import *" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "1daafde4", |
| "metadata": {}, |
| "source": [ |
| "## Downloading a model from the ONNX model zoo\n", |
| "\n", |
| "We download a pre-trained model, in our case the [GoogleNet](https://arxiv.org/abs/1409.4842) model, trained on [ImageNet](http://www.image-net.org/) from the [ONNX model zoo](https://github.com/onnx/models). The model comes packaged in an archive `tar.gz` file containing an `model.onnx` model file." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "7033dddb", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "base_url = \"https://s3.amazonaws.com/download.onnx/models/opset_3/\"\n", |
| "current_model = \"bvlc_googlenet\"\n", |
| "model_folder = \"model\"\n", |
| "archive_file = \"{}.tar.gz\".format(current_model)\n", |
| "archive_path = os.path.join(model_folder, archive_file)\n", |
| "url = \"{}{}\".format(base_url, archive_file)\n", |
| "onnx_path = os.path.join(model_folder, current_model, 'model.onnx')\n", |
| "\n", |
| "# Download the zipped model\n", |
| "mx.test_utils.download(url, dirname = model_folder)\n", |
| "\n", |
| "# Extract the model\n", |
| "if not os.path.isdir(os.path.join(model_folder, current_model)):\n", |
| " print('Extracting {} in {}...'.format(archive_path, model_folder))\n", |
| " tar = tarfile.open(archive_path, \"r:gz\")\n", |
| " tar.extractall(model_folder)\n", |
| " tar.close()\n", |
| " print('Model extracted.')" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "d871c2c5", |
| "metadata": {}, |
| "source": [ |
| "## Downloading the Caltech101 dataset\n", |
| "\n", |
| "The [Caltech101 dataset](https://data.caltech.edu/records/20086) is made of pictures of objects belonging to 101 categories. About 40 to 800 images per category. Most categories have about 50 images.\n", |
| "\n", |
| "*L. Fei-Fei, R. Fergus and P. Perona. Learning generative visual models from few training examples: an incremental Bayesian approach tested on 101 object categories. IEEE. CVPR 2004, Workshop on Generative-Model\n", |
| "Based Vision. 2004*" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "5929a9f4", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "data_folder = \"data\"\n", |
| "dataset_name = \"101_ObjectCategories\"\n", |
| "archive_file = \"{}.tar.gz\".format(dataset_name)\n", |
| "archive_path = os.path.join(data_folder, archive_file)\n", |
| "data_url = \"https://s3.us-east-2.amazonaws.com/mxnet-public/\"\n", |
| "\n", |
| "if not os.path.isfile(archive_path):\n", |
| " mx.test_utils.download(\"{}{}\".format(data_url, archive_file), dirname = data_folder)\n", |
| " print('Extracting {} in {}...'.format(archive_file, data_folder))\n", |
| " tar = tarfile.open(archive_path, \"r:gz\")\n", |
| " tar.extractall(data_folder)\n", |
| " tar.close()\n", |
| " print('Data extracted.')" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "8dea4e2e", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "training_path = os.path.join(data_folder, dataset_name)\n", |
| "testing_path = os.path.join(data_folder, \"{}_test\".format(dataset_name))" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "41eee09b", |
| "metadata": {}, |
| "source": [ |
| "### Load the data using an ImageFolderDataset and a DataLoader\n", |
| "\n", |
| "We need to transform the images to a format accepted by the network" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "eacf5634", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "EDGE = 224\n", |
| "SIZE = (EDGE, EDGE)\n", |
| "BATCH_SIZE = 32\n", |
| "NUM_WORKERS = 6" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "447d0423", |
| "metadata": {}, |
| "source": [ |
| "We transform the dataset images using the following operations:\n", |
| "- resize the shorter edge to 224, the longer edge will be greater or equal to 224\n", |
| "- center and crop an area of size (224,224)\n", |
| "- transpose the channels to be (3,224,224)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "5449a7d8", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def transform(image, label):\n", |
| " resized = mx.image.resize_short(image, EDGE)\n", |
| " cropped, crop_info = mx.image.center_crop(resized, SIZE)\n", |
| " transposed = np.transpose(cropped, (2,0,1))\n", |
| " return transposed, label" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "74468ea0", |
| "metadata": {}, |
| "source": [ |
| "The train and test dataset are created automatically by passing the root of each folder. The labels are built using the sub-folders names as label." |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "96692f77", |
| "metadata": {}, |
| "source": [ |
| "```\n", |
| "train_root\n", |
| "__label1\n", |
| "____image1\n", |
| "____image2\n", |
| "__label2\n", |
| "____image3\n", |
| "____image4\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "9f846f61", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "dataset_train = ImageFolderDataset(root=training_path)\n", |
| "dataset_test = ImageFolderDataset(root=testing_path)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "eb7b97c4", |
| "metadata": {}, |
| "source": [ |
| "We use several worker processes, which means the dataloading and pre-processing is going to be distributed across multiple processes. This will help preventing our GPU from starving and waiting for the data to be copied across" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "2c0062bb", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "dataloader_train = DataLoader(dataset_train.transform(transform, lazy=False), batch_size=BATCH_SIZE, last_batch='rollover',\n", |
| " shuffle=True, num_workers=NUM_WORKERS)\n", |
| "dataloader_test = DataLoader(dataset_test.transform(transform, lazy=False), batch_size=BATCH_SIZE, last_batch='rollover',\n", |
| " shuffle=False, num_workers=NUM_WORKERS)\n", |
| "print(\"Train dataset: {} images, Test dataset: {} images\".format(len(dataset_train), len(dataset_test)))" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "10e9a603", |
| "metadata": {}, |
| "source": [ |
| "`Train dataset: 6996 images, Test dataset: 1681 images`<!--notebook-skip-line-->" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "aece3021", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "categories = dataset_train.synsets\n", |
| "NUM_CLASSES = len(categories)\n", |
| "BATCH_SIZE = 32" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "a75450b3", |
| "metadata": {}, |
| "source": [ |
| "Let's plot the 1000th image to test the dataset" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "6ee85c09", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "N = 1000\n", |
| "plt.imshow((transform(dataset_train[N][0], 0)[0].asnumpy().transpose((1,2,0))))\n", |
| "plt.axis('off')\n", |
| "print(categories[dataset_train[N][1]])" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "c6739c94", |
| "metadata": {}, |
| "source": [ |
| "`Motorbikes`<!--notebook-skip-line-->\n", |
| "\n", |
| "\n", |
| "\n", |
| "![onnx motorbike](https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/onnx/motorbike.png?raw=true)<!--notebook-skip-line-->\n", |
| "\n", |
| "\n", |
| "## Fine-Tuning the ONNX model\n", |
| "\n", |
| "### Getting the last layer\n", |
| "\n", |
| "Load the ONNX model" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "8ad96605", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_path)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "fcade35d", |
| "metadata": {}, |
| "source": [ |
| "This function get the output of a given layer" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "4dbe7f52", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def get_layer_output(symbol, arg_params, aux_params, layer_name):\n", |
| " all_layers = symbol.get_internals()\n", |
| " net = all_layers[layer_name+'_output']\n", |
| " net = mx.symbol.Flatten(data=net)\n", |
| " new_args = dict({k:arg_params[k] for k in arg_params if k in net.list_arguments()})\n", |
| " new_aux = dict({k:aux_params[k] for k in aux_params if k in net.list_arguments()})\n", |
| " return (net, new_args, new_aux)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "3f25c425", |
| "metadata": {}, |
| "source": [ |
| "Here we print the different layers of the network to make it easier to pick the right one" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "a10935a7", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "sym.get_internals()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "9c5a264a", |
| "metadata": {}, |
| "source": [ |
| "```<Symbol group [data_0, pad0, conv1/7x7_s2_w_0, conv1/7x7_s2_b_0, convolution0, relu0, pad1, pooling0, lrn0, pad2, conv2/3x3_reduce_w_0, conv2/3x3_reduce_b_0, convolution1, relu1, pad3, conv2/3x3_w_0, conv2/3x3_b_0, convolution2, relu2, lrn1, pad4, pooling1, pad5, inception_3a/1x1_w_0, inception_3a/1x1_b_0, convolution3, relu3, pad6, .................................................................................inception_5b/pool_proj_b_0, convolution56, relu56, concat8, pad70, pooling13, dropout0, flatten0, loss3/classifier_w_0, linalg_gemm20, loss3/classifier_b_0, _mulscalar0, broadcast_add0, softmax0]>```<!--notebook-skip-line-->\n", |
| "\n", |
| "\n", |
| "\n", |
| "We get the network until the output of the `flatten0` layer\n", |
| "\n", |
| "\n", |
| "```{.python .input}\n", |
| "new_sym, new_arg_params, new_aux_params = get_layer_output(sym, arg_params, aux_params, 'flatten0')\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "5bfac66e", |
| "metadata": {}, |
| "source": [ |
| "### Fine-tuning in gluon\n", |
| "\n", |
| "\n", |
| "We can now take advantage of the features and pattern detection knowledge that our network learnt training on ImageNet, and apply that to the new Caltech101 dataset.\n", |
| "\n", |
| "\n", |
| "We pick a device, fine-tuning on CPU will be **WAY** slower." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "4ea751e4", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "device = mx.gpu() if mx.device.num_gpus() > 0 else mx.cpu()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "925bae68", |
| "metadata": {}, |
| "source": [ |
| "We create a symbol block that is going to hold all our pre-trained layers, and assign the weights of the different pre-trained layers to the newly created SymbolBlock" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "2495b17d", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import warnings\n", |
| "with warnings.catch_warnings():\n", |
| " warnings.simplefilter(\"ignore\")\n", |
| " pre_trained = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var('data_0'))\n", |
| "net_params = pre_trained.collect_params()\n", |
| "for param in new_arg_params:\n", |
| " if param in net_params:\n", |
| " net_params[param]._load_init(new_arg_params[param], device=device)\n", |
| "for param in new_aux_params:\n", |
| " if param in net_params:\n", |
| " net_params[param]._load_init(new_aux_params[param], device=device)\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "eec7cfec", |
| "metadata": {}, |
| "source": [ |
| "We create the new dense layer with the right new number of classes (101) and initialize the weights" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "7e6378b5", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "dense_layer = gluon.nn.Dense(NUM_CLASSES)\n", |
| "dense_layer.initialize(mx.init.Xavier(magnitude=2.24), device=device)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "ced989fd", |
| "metadata": {}, |
| "source": [ |
| "We add the SymbolBlock and the new dense layer to a HybridSequential network" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "a863b765", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "net = gluon.nn.HybridSequential()\n", |
| "net.add(pre_trained)\n", |
| "net.add(dense_layer)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "af7cc648", |
| "metadata": {}, |
| "source": [ |
| "### Loss\n", |
| "Softmax cross entropy for multi-class classification" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "3cbcc1f2", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "15f61da3", |
| "metadata": {}, |
| "source": [ |
| "### Trainer\n", |
| "Initialize trainer with common training parameters" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "47345c94", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "LEARNING_RATE = 0.0005\n", |
| "WDECAY = 0.00001\n", |
| "MOMENTUM = 0.9" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "4418a5f7", |
| "metadata": {}, |
| "source": [ |
| "The trainer will retrain and fine-tune the entire network. If we use `dense_layer` instead of `net` in the cell below, the gradient updates would only be applied to the new last dense layer. Essentially we would be using the pre-trained network as a featurizer." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "fd9bb942", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "trainer = gluon.Trainer(net.collect_params(), 'sgd',\n", |
| " {'learning_rate': LEARNING_RATE,\n", |
| " 'wd':WDECAY,\n", |
| " 'momentum':MOMENTUM})" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "a23d3cb7", |
| "metadata": {}, |
| "source": [ |
| "### Evaluation loop\n", |
| "\n", |
| "We measure the accuracy in a non-blocking way, using `np.array` to take care of the parallelisation that MXNet and Gluon offers." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "12cfbcec", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| " def evaluate_accuracy_gluon(data_iterator, net):\n", |
| " num_instance = 0\n", |
| " sum_metric = np.zeros(1,device=device, dtype=np.int32)\n", |
| " for i, (data, label) in enumerate(data_iterator):\n", |
| " data = data.astype(np.float32).to_device(device)\n", |
| " label = label.astype(np.int32).to_device(device)\n", |
| " output = net(data)\n", |
| " prediction = np.argmax(output, axis=1).astype(np.int32)\n", |
| " num_instance += len(prediction)\n", |
| " sum_metric += (prediction==label).sum()\n", |
| " accuracy = (sum_metric.astype(np.float32)/num_instance)\n", |
| " return accuracy.item()" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "02c13f7b", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "%%time\n", |
| "print(\"Untrained network Test Accuracy: {0:.4f}\".format(evaluate_accuracy_gluon(dataloader_test, net)))" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "1f4cacf7", |
| "metadata": {}, |
| "source": [ |
| "`Untrained network Test Accuracy: 0.0192`<!--notebook-skip-line-->\n", |
| "\n", |
| "\n", |
| "\n", |
| "### Training loop" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "c05a268b", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "val_accuracy = 0\n", |
| "for epoch in range(5):\n", |
| " for i, (data, label) in enumerate(dataloader_train):\n", |
| " data = data.astype(np.float32).to_device(device)\n", |
| " label = label.to_device(device)\n", |
| "\n", |
| " if i%20==0 and i >0:\n", |
| " print('Batch [{0}] loss: {1:.4f}'.format(i, loss.mean().item()))\n", |
| "\n", |
| " with autograd.record():\n", |
| " output = net(data)\n", |
| " loss = softmax_cross_entropy(output, label)\n", |
| " loss.backward()\n", |
| " trainer.step(data.shape[0])\n", |
| "\n", |
| " npx.waitall() # wait at the end of the epoch\n", |
| " new_val_accuracy = evaluate_accuracy_gluon(dataloader_test, net)\n", |
| " print(\"Epoch [{0}] Test Accuracy {1:.4f} \".format(epoch, new_val_accuracy))\n", |
| "\n", |
| " # We perform early-stopping regularization, to prevent the model from overfitting\n", |
| " if val_accuracy > new_val_accuracy:\n", |
| " print('Validation accuracy is decreasing, stopping training')\n", |
| " break\n", |
| " val_accuracy = new_val_accuracy" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "7b45a49c", |
| "metadata": {}, |
| "source": [ |
| "`Epoch 4, Test Accuracy 0.8942307829856873`<!--notebook-skip-line-->\n", |
| "\n", |
| "\n", |
| "## Testing\n", |
| "In the previous tutorial, we saw that the network trained on ImageNet couldn't classify correctly `wrench`, `dolphin`, `lotus` because these are not categories of the ImageNet dataset.\n", |
| "\n", |
| "Let's see if our network fine-tuned on Caltech101 is up for the task:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "5d5352a5", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Number of predictions to show\n", |
| "TOP_P = 3" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "2a8cca99", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Convert img to format expected by the network\n", |
| "def transform(img):\n", |
| " return np.array(np.expand_dims(np.transpose(img, (2,0,1)),axis=0).astype(np.float32), device=device)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "c8ed9dfd", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Load and transform the test images\n", |
| "caltech101_images_test = [plt.imread(os.path.join(image_folder, \"{}\".format(img))) for img in images]\n", |
| "caltech101_images_transformed = [transform(img) for img in caltech101_images_test]" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "3d78d13d", |
| "metadata": {}, |
| "source": [ |
| "Helper function to run batches of data" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "f25407c6", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def run_batch(net, data):\n", |
| " results = []\n", |
| " for batch in data:\n", |
| " outputs = net(batch)\n", |
| " results.extend([o for o in outputs.asnumpy()])\n", |
| " return np.array(results)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "bcb32df4", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "result = run_batch(net, caltech101_images_transformed)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "c825fe29", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "plot_predictions(caltech101_images_test, result, categories, TOP_P)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "7e948cdb", |
| "metadata": {}, |
| "source": [ |
| "![onnx caltech101 correct](https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/onnx/caltech101_correct.png?raw=true)<!--notebook-skip-line-->\n", |
| "\n", |
| "\n", |
| "**Great!** The network classified these images correctly after being fine-tuned on a dataset that contains images of `wrench`, `dolphin` and `lotus`" |
| ] |
| } |
| ], |
| "metadata": { |
| "language_info": { |
| "name": "python" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |