{
"cells": [
{
"cell_type": "markdown",
"id": "5dc266a6",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"\n",
"# Running inference on MXNet/Gluon from an ONNX model\n",
"\n",
"[Open Neural Network Exchange (ONNX)](https://github.com/onnx/onnx) provides an open source format for AI models. It defines an extensible computation graph model, as well as definitions of built-in operators and standard data types.\n",
"\n",
"In this tutorial we will:\n",
" \n",
"- learn how to load a pre-trained .onnx model file into MXNet/Gluon\n",
"- learn how to test this model using the sample input/output\n",
"- learn how to test the model on custom images\n",
"\n",
"## Pre-requisite\n",
"\n",
"To run the tutorial you will need to have installed the following python modules:\n",
"- [MXNet > 1.1.0](/get_started)\n",
"- [onnx](https://github.com/onnx/onnx) (follow the install guide)\n",
"- matplotlib"
]
},
{
"cell_type": "markdown",
"id": "8c76a8d2",
"metadata": {},
"source": [
"```python\n",
"import numpy as np\n",
"import mxnet as mx\n",
"from mxnet.contrib import onnx as onnx_mxnet\n",
"from mxnet import gluon, nd\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import tarfile, os\n",
"import json\n",
"import logging\n",
"logging.basicConfig(level=logging.INFO)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "d188ce5b",
"metadata": {},
"source": [
"### Downloading supporting files\n",
"These are images and a vizualisation script"
]
},
{
"cell_type": "markdown",
"id": "211651cd",
"metadata": {},
"source": [
"```python\n",
"image_folder = \"images\"\n",
"utils_file = \"utils.py\" # contain utils function to plot nice visualization\n",
"image_net_labels_file = \"image_net_labels.json\"\n",
"images = ['apron.jpg', 'hammerheadshark.jpg', 'dog.jpg', 'wrench.jpg', 'dolphin.jpg', 'lotus.jpg']\n",
"base_url = \"https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/{}?raw=true\"\n",
"\n",
"for image in images:\n",
" mx.test_utils.download(base_url.format(\"{}/{}\".format(image_folder, image)), fname=image,dirname=image_folder)\n",
"mx.test_utils.download(base_url.format(utils_file), fname=utils_file)\n",
"mx.test_utils.download(base_url.format(image_net_labels_file), fname=image_net_labels_file)\n",
"\n",
"from utils import * \n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "0cdc8bba",
"metadata": {},
"source": [
"## Downloading a model from the ONNX model zoo\n",
"\n",
"We download a pre-trained model, in our case the [GoogleNet](https://arxiv.org/abs/1409.4842) model, trained on [ImageNet](http://www.image-net.org/) from the [ONNX model zoo](https://github.com/onnx/models). The model comes packaged in an archive `tar.gz` file containing an `model.onnx` model file."
]
},
{
"cell_type": "markdown",
"id": "b0c8b382",
"metadata": {},
"source": [
"```python\n",
"base_url = \"https://s3.amazonaws.com/download.onnx/models/opset_3/\" \n",
"current_model = \"bvlc_googlenet\"\n",
"model_folder = \"model\"\n",
"archive = \"{}.tar.gz\".format(current_model)\n",
"archive_file = os.path.join(model_folder, archive)\n",
"url = \"{}{}\".format(base_url, archive)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "aa580b25",
"metadata": {},
"source": [
"Download and extract pre-trained model"
]
},
{
"cell_type": "markdown",
"id": "c154889c",
"metadata": {},
"source": [
"```python\n",
"mx.test_utils.download(url, dirname = model_folder)\n",
"if not os.path.isdir(os.path.join(model_folder, current_model)):\n",
" print('Extracting model...')\n",
" tar = tarfile.open(archive_file, \"r:gz\")\n",
" tar.extractall(model_folder)\n",
" tar.close()\n",
" print('Extracted')\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "7dc9b814",
"metadata": {},
"source": [
"The models have been pre-trained on ImageNet, let's load the label mapping of the 1000 classes."
]
},
{
"cell_type": "markdown",
"id": "c8e7a766",
"metadata": {},
"source": [
"```python\n",
"categories = json.load(open(image_net_labels_file, 'r'))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "cf43e1b9",
"metadata": {},
"source": [
"## Loading the model into MXNet Gluon"
]
},
{
"cell_type": "markdown",
"id": "5cbe4eeb",
"metadata": {},
"source": [
"```python\n",
"onnx_path = os.path.join(model_folder, current_model, \"model.onnx\")\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "20d6b1ad",
"metadata": {},
"source": [
"We get the symbol and parameter objects"
]
},
{
"cell_type": "markdown",
"id": "e578d2b4",
"metadata": {},
"source": [
"```python\n",
"sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_path)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "06055ac0",
"metadata": {},
"source": [
"We pick a context, CPU is fine for inference, switch to mx.gpu() if you want to use your GPU."
]
},
{
"cell_type": "markdown",
"id": "72ba4cd3",
"metadata": {},
"source": [
"```python\n",
"ctx = mx.cpu()\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "862663e8",
"metadata": {},
"source": [
"We obtain the data names of the inputs to the model by using the model metadata API:"
]
},
{
"cell_type": "markdown",
"id": "b73a131d",
"metadata": {},
"source": [
"```python\n",
"model_metadata = onnx_mxnet.get_model_metadata(onnx_path)\n",
"print(model_metadata)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "9c70eb04",
"metadata": {},
"source": [
"```\n",
"{'output_tensor_data': [(u'gpu_0/softmax_1', (1L, 1000L))],\n",
" 'input_tensor_data': [(u'gpu_0/data_0', (1L, 3L, 224L, 224L))]}\n",
"```\n"
]
},
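{
"cell_type": "markdown",
"id": "3e9a1f07",
"metadata": {},
"source": [
"As a small illustration of how this dictionary is laid out, the sketch below copies the printed output into a literal and unpacks each input entry as a `(name, shape)` pair. The name `example_metadata` is chosen for this sketch so that it does not overwrite the real `model_metadata`; the Python 2 `u` and `L` markers are dropped, and everything else is taken from the output above.\n",
"\n",
"```python\n",
"# Copy of the metadata printed above; `example_metadata` is a name chosen for this sketch\n",
"example_metadata = {\n",
"    'output_tensor_data': [('gpu_0/softmax_1', (1, 1000))],\n",
"    'input_tensor_data': [('gpu_0/data_0', (1, 3, 224, 224))],\n",
"}\n",
"\n",
"# Each entry is a (tensor name, shape) tuple\n",
"for name, shape in example_metadata['input_tensor_data']:\n",
"    batch_size, channels, height, width = shape\n",
"    print(name, batch_size, channels, height, width)\n",
"```\n"
]
},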
{
"cell_type": "markdown",
"id": "77a3f404",
"metadata": {},
"source": [
"```python\n",
"data_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]\n",
"print(data_names)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "eec3c48b",
"metadata": {},
"source": [
"```[u'data_0']```<!--notebook-skip-line-->\n",
"\n",
"And load them into a MXNet Gluon symbol block. \n",
"\n",
"```python\n",
"import warnings\n",
"with warnings.catch_warnings():\n",
" warnings.simplefilter(\"ignore\")\n",
" net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data_0'))\n",
"net_params = net.collect_params()\n",
"for param in arg_params:\n",
" if param in net_params:\n",
" net_params[param]._load_init(arg_params[param], ctx=ctx)\n",
"for param in aux_params:\n",
" if param in net_params:\n",
" net_params[param]._load_init(aux_params[param], ctx=ctx)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "f1cd7794",
"metadata": {},
"source": [
"We can now cache the computational graph through [hybridization](https://mxnet.apache.org/tutorials/gluon/hybrid.html) to gain some performance"
]
},
{
"cell_type": "markdown",
"id": "8f609c7e",
"metadata": {},
"source": [
"```python\n",
"net.hybridize()\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "a6d8351b",
"metadata": {},
"source": [
"We can visualize the network (requires graphviz installed)"
]
},
{
"cell_type": "markdown",
"id": "9d3d7319",
"metadata": {},
"source": [
"```python\n",
"mx.visualization.plot_network(sym, node_attrs={\"shape\":\"oval\",\"fixedsize\":\"false\"})\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "584cefe7",
"metadata": {},
"source": [
"![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/network2.png?raw=true)<!--notebook-skip-line-->\n",
"\n",
"\n",
"\n",
"This is a helper function to run M batches of data of batch-size N through the net and collate the outputs into an array of shape (K, 1000) where K=MxN is the total number of examples (mumber of batches x batch-size) run through the network."
]
},
{
"cell_type": "markdown",
"id": "fbb957cb",
"metadata": {},
"source": [
"```python\n",
"def run_batch(net, data):\n",
" results = []\n",
" for batch in data:\n",
" outputs = net(batch)\n",
" results.extend([o for o in outputs.asnumpy()])\n",
" return np.array(results)\n",
"```\n"
]
},
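{
"cell_type": "markdown",
"id": "6b2d4c81",
"metadata": {},
"source": [
"To see the collation without loading the model, the following sketch runs the same helper against a stand-in network. `FakeOutput` and `fake_net` are hypothetical names invented for this illustration; they only mimic the `asnumpy()` call that `run_batch` relies on and return zeros instead of real predictions.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"class FakeOutput:\n",
"    # Hypothetical stand-in for an MXNet NDArray; only asnumpy() is mimicked\n",
"    def __init__(self, array):\n",
"        self._array = array\n",
"\n",
"    def asnumpy(self):\n",
"        return self._array\n",
"\n",
"def fake_net(batch):\n",
"    # Hypothetical network: one row of 1000 scores per example in the batch\n",
"    return FakeOutput(np.zeros((batch.shape[0], 1000), dtype=np.float32))\n",
"\n",
"def run_batch(net, data):\n",
"    # Same helper as above, repeated so the snippet is self-contained\n",
"    results = []\n",
"    for batch in data:\n",
"        outputs = net(batch)\n",
"        results.extend([o for o in outputs.asnumpy()])\n",
"    return np.array(results)\n",
"\n",
"# M = 2 batches of N = 3 examples each, so K = 6\n",
"batches = [np.zeros((3, 3, 224, 224), dtype=np.float32) for _ in range(2)]\n",
"collated = run_batch(fake_net, batches)\n",
"print(collated.shape)  # (6, 1000)\n",
"```\n"
]
},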
{
"cell_type": "markdown",
"id": "97b0ca03",
"metadata": {},
"source": [
"## Test using real images"
]
},
{
"cell_type": "markdown",
"id": "bebd2386",
"metadata": {},
"source": [
"```python\n",
"TOP_P = 3 # How many top guesses we show in the visualization\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "9a3a6c60",
"metadata": {},
"source": [
"Transform function to set the data into the format the network expects, (N, 3, 224, 224) where N is the batch size."
]
},
{
"cell_type": "markdown",
"id": "ffc31b50",
"metadata": {},
"source": [
"```python\n",
"def transform(img):\n",
" return np.expand_dims(np.transpose(img, (2,0,1)),axis=0).astype(np.float32)\n",
"```\n"
]
},
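{
"cell_type": "markdown",
"id": "e40f7a95",
"metadata": {},
"source": [
"As a quick check of the layout change, the sketch below applies the same transform to a synthetic image. The all-zero `dummy_img` is a hypothetical stand-in for a decoded 224x224 RGB picture. Note that the transform does not resize, so real inputs are assumed to already be 224x224.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def transform(img):\n",
"    # Same HWC -> NCHW conversion as above, repeated so the snippet is self-contained\n",
"    return np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32)\n",
"\n",
"# Hypothetical stand-in for a decoded image: (height, width, channels)\n",
"dummy_img = np.zeros((224, 224, 3), dtype=np.uint8)\n",
"dummy_batch = transform(dummy_img)\n",
"print(dummy_batch.shape, dummy_batch.dtype)  # (1, 3, 224, 224) float32\n",
"```\n",
"\n",
"Each image becomes a batch of size 1, which is why the transformed images can later be concatenated along axis 0."
]
},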
{
"cell_type": "markdown",
"id": "a056eade",
"metadata": {},
"source": [
"We load two sets of images in memory"
]
},
{
"cell_type": "markdown",
"id": "5b30de3b",
"metadata": {},
"source": [
"```python\n",
"image_net_images = [plt.imread('{}/{}.jpg'.format(image_folder, path)) for path in ['apron', 'hammerheadshark','dog']]\n",
"caltech101_images = [plt.imread('{}/{}.jpg'.format(image_folder, path)) for path in ['wrench', 'dolphin','lotus']]\n",
"images = image_net_images + caltech101_images\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "d1abba01",
"metadata": {},
"source": [
"And run them as a batch through the network to get the predictions"
]
},
{
"cell_type": "markdown",
"id": "fcfbfa26",
"metadata": {},
"source": [
"```python\n",
"batch = nd.array(np.concatenate([transform(img) for img in images], axis=0), ctx=ctx)\n",
"result = run_batch(net, [batch])\n",
"```\n"
]
},
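{
"cell_type": "markdown",
"id": "1c58d3b6",
"metadata": {},
"source": [
"The predictions can also be inspected without plotting. A minimal sketch, assuming each row of `result` holds one score per class: the indices of the highest scores are found with `np.argsort`. The helper name `top_p_indices` and the five-class score vector are made up for this illustration; the real rows have 1000 entries.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def top_p_indices(scores, p):\n",
"    # Indices of the p largest scores, highest first\n",
"    return np.argsort(scores)[::-1][:p]\n",
"\n",
"# Hypothetical scores for one image over 5 classes\n",
"example_scores = np.array([0.05, 0.60, 0.10, 0.20, 0.05])\n",
"top = top_p_indices(example_scores, 3)\n",
"print(top)  # [1 3 2]\n",
"```\n",
"\n",
"Applied to the real output, `top_p_indices(result[0], TOP_P)` would give the class indices for the first image, which can then be looked up in `categories`, assuming the labels file is indexed by class id."
]
},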
{
"cell_type": "markdown",
"id": "29dbf9c5",
"metadata": {},
"source": [
"```python\n",
"plot_predictions(image_net_images, result[:3], categories, TOP_P)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "036d69cd",
"metadata": {},
"source": [
"![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/imagenet.png?raw=true)<!--notebook-skip-line-->\n",
"\n",
"\n",
"**Well done!** Looks like it is doing a pretty good job at classifying pictures when the category is a ImageNet label\n",
"\n",
"Let's now see the results on the 3 other images"
]
},
{
"cell_type": "markdown",
"id": "1b3c367d",
"metadata": {},
"source": [
"```python\n",
"plot_predictions(caltech101_images, result[3:7], categories, TOP_P)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "3fa9d619",
"metadata": {},
"source": [
"![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/caltech101.png?raw=true)<!--notebook-skip-line-->\n",
"\n",
"\n",
"**Hmm, not so good...** Even though predictions are close, they are not accurate, which is due to the fact that the ImageNet dataset does not contain `wrench`, `dolphin`, or `lotus` categories and our network has been trained on ImageNet.\n",
"\n",
"Lucky for us, the [Caltech101 dataset](http://www.vision.caltech.edu/Image_Datasets/Caltech101/) has them, let's see how we can fine-tune our network to classify these categories correctly.\n",
"\n",
"We show that in our next tutorial:\n",
"\n",
"\n",
"- [Fine-tuning an ONNX Model using the modern imperative MXNet/Gluon](http://mxnet.apache.org/tutorials/onnx/fine_tuning_gluon.html)\n",
" \n",
"<!-- INSERT SOURCE DOWNLOAD BUTTONS -->"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}