{
"cells": [
{
"cell_type": "markdown",
"id": "954625fb",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"\n",
"# Running inference on MXNet/Gluon from an ONNX model\n",
"\n",
"[Open Neural Network Exchange (ONNX)](https://github.com/onnx/onnx) provides an open source format for AI models. It defines an extensible computation graph model, as well as definitions of built-in operators and standard data types.\n",
"\n",
"In this tutorial we will:\n",
"\n",
"- learn how to load a pre-trained .onnx model file into MXNet/Gluon\n",
"- learn how to test this model using the sample input/output\n",
"- learn how to test the model on custom images\n",
"\n",
"## Pre-requisite\n",
"\n",
"To run the tutorial you will need to have installed the following python modules:\n",
"- [MXNet > 1.1.0](https://mxnet.apache.org/get_started)\n",
"- [onnx](https://github.com/onnx/onnx) (follow the install guide)\n",
"- matplotlib"
]
},
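{
"cell_type": "markdown",
"id": "b3f1a2c0",
"metadata": {},
"source": [
"As a minimal sketch (the exact package names and versions are assumptions, not part of the original tutorial), the prerequisites can typically be installed with `pip`. Pick the MXNet build (CPU or CUDA) and the onnx version that match your environment; see the install guides linked above for details."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7d2e8f1",
"metadata": {},
"outputs": [],
"source": [
"# Sketch only: adjust package names/versions to your environment (see the MXNet and onnx install guides).\n",
"!pip install mxnet        # or a CUDA-specific MXNet build if you plan to run on a GPU\n",
"!pip install onnx         # check version compatibility with your MXNet release\n",
"!pip install matplotlib"
]
},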
{
"cell_type": "code",
"execution_count": null,
"id": "4a4dc4d4",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import mxnet as mx\n",
"from mxnet.contrib import onnx as onnx_mxnet\n",
"from mxnet import gluon, nd\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import tarfile, os\n",
"import json\n",
"import logging\n",
"logging.basicConfig(level=logging.INFO)"
]
},
{
"cell_type": "markdown",
"id": "c8938f9e",
"metadata": {},
"source": [
"### Downloading supporting files\n",
"These are images and a vizualisation script"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e749c9c9",
"metadata": {},
"outputs": [],
"source": [
"image_folder = \"images\"\n",
"utils_file = \"utils.py\" # contain utils function to plot nice visualization\n",
"image_net_labels_file = \"image_net_labels.json\"\n",
"images = ['apron.jpg', 'hammerheadshark.jpg', 'dog.jpg', 'wrench.jpg', 'dolphin.jpg', 'lotus.jpg']\n",
"base_url = \"https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/{}?raw=true\"\n",
"\n",
"for image in images:\n",
" mx.test_utils.download(base_url.format(\"{}/{}\".format(image_folder, image)), fname=image,dirname=image_folder)\n",
"mx.test_utils.download(base_url.format(utils_file), fname=utils_file)\n",
"mx.test_utils.download(base_url.format(image_net_labels_file), fname=image_net_labels_file)\n",
"\n",
"from utils import *"
]
},
{
"cell_type": "markdown",
"id": "4518dbdf",
"metadata": {},
"source": [
"## Downloading a model from the ONNX model zoo\n",
"\n",
"We download a pre-trained model, in our case the [GoogleNet](https://arxiv.org/abs/1409.4842) model, trained on [ImageNet](http://www.image-net.org/) from the [ONNX model zoo](https://github.com/onnx/models). The model comes packaged in an archive `tar.gz` file containing an `model.onnx` model file."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "37b2df22",
"metadata": {},
"outputs": [],
"source": [
"base_url = \"https://s3.amazonaws.com/download.onnx/models/opset_3/\"\n",
"current_model = \"bvlc_googlenet\"\n",
"model_folder = \"model\"\n",
"archive = \"{}.tar.gz\".format(current_model)\n",
"archive_file = os.path.join(model_folder, archive)\n",
"url = \"{}{}\".format(base_url, archive)"
]
},
{
"cell_type": "markdown",
"id": "23e29ce7",
"metadata": {},
"source": [
"Download and extract pre-trained model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10cf2e1e",
"metadata": {},
"outputs": [],
"source": [
"mx.test_utils.download(url, dirname = model_folder)\n",
"if not os.path.isdir(os.path.join(model_folder, current_model)):\n",
" print('Extracting model...')\n",
" tar = tarfile.open(archive_file, \"r:gz\")\n",
" tar.extractall(model_folder)\n",
" tar.close()\n",
" print('Extracted')"
]
},
{
"cell_type": "markdown",
"id": "ac55756f",
"metadata": {},
"source": [
"The models have been pre-trained on ImageNet, let's load the label mapping of the 1000 classes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e93ed870",
"metadata": {},
"outputs": [],
"source": [
"categories = json.load(open(image_net_labels_file, 'r'))"
]
},
{
"cell_type": "markdown",
"id": "aa43b716",
"metadata": {},
"source": [
"## Loading the model into MXNet Gluon"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e4c0cae",
"metadata": {},
"outputs": [],
"source": [
"onnx_path = os.path.join(model_folder, current_model, \"model.onnx\")"
]
},
{
"cell_type": "markdown",
"id": "4fb3ef07",
"metadata": {},
"source": [
"We get the symbol and parameter objects"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9c13e57",
"metadata": {},
"outputs": [],
"source": [
"sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_path)"
]
},
{
"cell_type": "markdown",
"id": "e6908285",
"metadata": {},
"source": [
"We pick a device, CPU is fine for inference, switch to mx.gpu() if you want to use your GPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c08c7b5d",
"metadata": {},
"outputs": [],
"source": [
"device = mx.cpu()"
]
},
{
"cell_type": "markdown",
"id": "ae61bb3a",
"metadata": {},
"source": [
"We obtain the data names of the inputs to the model by using the model metadata API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "635c8818",
"metadata": {},
"outputs": [],
"source": [
"model_metadata = onnx_mxnet.get_model_metadata(onnx_path)\n",
"print(model_metadata)"
]
},
{
"cell_type": "markdown",
"id": "af7810b8",
"metadata": {},
"source": [
"```\n",
"{'output_tensor_data': [(u'gpu_0/softmax_1', (1L, 1000L))],\n",
" 'input_tensor_data': [(u'gpu_0/data_0', (1L, 3L, 224L, 224L))]}\n",
"```\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9cfdac34",
"metadata": {},
"outputs": [],
"source": [
"data_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]\n",
"print(data_names)"
]
},
{
"cell_type": "markdown",
"id": "a8a3a8aa",
"metadata": {},
"source": [
"And load them into a MXNet Gluon symbol block."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80cd0957",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"with warnings.catch_warnings():\n",
" warnings.simplefilter(\"ignore\")\n",
" net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data_0'))\n",
"net_params = net.collect_params()\n",
"for param in arg_params:\n",
" if param in net_params:\n",
" net_params[param]._load_init(arg_params[param], device=device)\n",
"for param in aux_params:\n",
" if param in net_params:\n",
" net_params[param]._load_init(aux_params[param], device=device)"
]
},
{
"cell_type": "markdown",
"id": "069d723d",
"metadata": {},
"source": [
"We can now cache the computational graph through [hybridization](https://mxnet.apache.org/versions/master/api/python/docs/tutorials/packages/gluon/blocks/hybridize.html) to gain some performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "314dafe3",
"metadata": {},
"outputs": [],
"source": [
"net.hybridize()"
]
},
{
"cell_type": "markdown",
"id": "91aeaba8",
"metadata": {},
"source": [
"We can visualize the network (requires graphviz installed)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65f60def",
"metadata": {},
"outputs": [],
"source": [
"mx.visualization.plot_network(sym, node_attrs={\"shape\":\"oval\",\"fixedsize\":\"false\"})"
]
},
{
"cell_type": "markdown",
"id": "ace59190",
"metadata": {},
"source": [
"![network2](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/network2.png?raw=true)<!--notebook-skip-line-->\n",
"\n",
"\n",
"\n",
"This is a helper function to run M batches of data of batch-size N through the net and collate the outputs into an array of shape (K, 1000) where K=MxN is the total number of examples (mumber of batches x batch-size) run through the network."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d424af63",
"metadata": {},
"outputs": [],
"source": [
"def run_batch(net, data):\n",
" results = []\n",
" for batch in data:\n",
" outputs = net(batch)\n",
" results.extend([o for o in outputs.asnumpy()])\n",
" return np.array(results)"
]
},
{
"cell_type": "markdown",
"id": "3d51763b",
"metadata": {},
"source": [
"## Test using real images"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27e59f47",
"metadata": {},
"outputs": [],
"source": [
"TOP_P = 3 # How many top guesses we show in the visualization"
]
},
{
"cell_type": "markdown",
"id": "4519d05d",
"metadata": {},
"source": [
"Transform function to set the data into the format the network expects, (N, 3, 224, 224) where N is the batch size."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "acd10b01",
"metadata": {},
"outputs": [],
"source": [
"def transform(img):\n",
" return np.expand_dims(np.transpose(img, (2,0,1)),axis=0).astype(np.float32)"
]
},
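{
"cell_type": "markdown",
"id": "1f2e3d4c",
"metadata": {},
"source": [
"If you want to try your own pictures, they first need to be brought to the 224x224 HWC layout that `transform` expects. The next cell is a minimal sketch (not part of the original tutorial) of loading and resizing an image file with `mx.image.imread` and `mx.image.imresize`; here it simply reuses one of the downloaded tutorial images, so swap in your own path to test other pictures."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a6b7c8d",
"metadata": {},
"outputs": [],
"source": [
"# Sketch only: reusing a downloaded tutorial image; replace the path with your own file to test other pictures.\n",
"my_img = mx.image.imread(os.path.join(image_folder, 'dog.jpg'))  # load as an HWC uint8 NDArray\n",
"my_img = mx.image.imresize(my_img, 224, 224)                     # resize to the 224x224 input size the network expects\n",
"my_batch = nd.array(transform(my_img.asnumpy()))                 # shape (1, 3, 224, 224), float32\n",
"print(my_batch.shape)"
]
},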
{
"cell_type": "markdown",
"id": "59f29daa",
"metadata": {},
"source": [
"We load two sets of images in memory"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "90f0040f",
"metadata": {},
"outputs": [],
"source": [
"image_net_images = [plt.imread('{}/{}.jpg'.format(image_folder, path)) for path in ['apron', 'hammerheadshark','dog']]\n",
"caltech101_images = [plt.imread('{}/{}.jpg'.format(image_folder, path)) for path in ['wrench', 'dolphin','lotus']]\n",
"images = image_net_images + caltech101_images"
]
},
{
"cell_type": "markdown",
"id": "b7306dfb",
"metadata": {},
"source": [
"And run them as a batch through the network to get the predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8eb1aa79",
"metadata": {},
"outputs": [],
"source": [
"batch = nd.array(np.concatenate([transform(img) for img in images], axis=0), device=device)\n",
"result = run_batch(net, [batch])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4fee3e6b",
"metadata": {},
"outputs": [],
"source": [
"plot_predictions(image_net_images, result[:3], categories, TOP_P)"
]
},
{
"cell_type": "markdown",
"id": "4643d505",
"metadata": {},
"source": [
"![imagenet](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/imagenet.png?raw=true)<!--notebook-skip-line-->\n",
"\n",
"\n",
"**Well done!** Looks like it is doing a pretty good job at classifying pictures when the category is a ImageNet label\n",
"\n",
"Let's now see the results on the 3 other images"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f6d1ee8",
"metadata": {},
"outputs": [],
"source": [
"plot_predictions(caltech101_images, result[3:7], categories, TOP_P)"
]
},
{
"cell_type": "markdown",
"id": "110f4ef4",
"metadata": {},
"source": [
"![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/caltech101.png?raw=true)<!--notebook-skip-line-->\n",
"\n",
"\n",
"**Hmm, not so good...** Even though predictions are close, they are not accurate, which is due to the fact that the ImageNet dataset does not contain `wrench`, `dolphin`, or `lotus` categories and our network has been trained on ImageNet.\n",
"\n",
"Lucky for us, the [Caltech101 dataset](https://data.caltech.edu/records/20086) has them, let's see how we can fine-tune our network to classify these categories correctly.\n",
"\n",
"We show that in our next tutorial:\n",
"\n",
"\n",
"- [Fine-tuning an ONNX Model using the modern imperative MXNet/Gluon](https://mxnet.apache.org/versions/master/api/python/docs/tutorials/packages/onnx/fine_tuning_gluon.html)\n",
"\n",
"<!-- INSERT SOURCE DOWNLOAD BUTTONS -->"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}