"cells": [
"cell_type": "markdown",
"id": "8872dc38",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"<!--- -->\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"# Running inference on MXNet/Gluon from an ONNX model\n",
"[Open Neural Network Exchange (ONNX)]( provides an open source format for AI models. It defines an extensible computation graph model, as well as definitions of built-in operators and standard data types.\n",
"In this tutorial we will:\n",
" \n",
"- learn how to load a pre-trained .onnx model file into MXNet/Gluon\n",
"- learn how to test this model using the sample input/output\n",
"- learn how to test the model on custom images\n",
"## Pre-requisite\n",
"To run the tutorial you will need to have installed the following python modules:\n",
"- [MXNet > 1.1.0](/get_started)\n",
"- [onnx]( (follow the install guide)\n",
"- matplotlib"
"cell_type": "markdown",
"id": "af440272",
"metadata": {},
"source": [
"import numpy as np\n",
"import mxnet as mx\n",
"from mxnet.contrib import onnx as onnx_mxnet\n",
"from mxnet import gluon, nd\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import tarfile, os\n",
"import json\n",
"import logging\n",
"cell_type": "markdown",
"id": "16c099c8",
"metadata": {},
"source": [
"### Downloading supporting files\n",
"These are images and a vizualisation script"
"cell_type": "markdown",
"id": "853f948e",
"metadata": {},
"source": [
"image_folder = \"images\"\n",
"utils_file = \"\" # contain utils function to plot nice visualization\n",
"image_net_labels_file = \"image_net_labels.json\"\n",
"images = ['apron.jpg', 'hammerheadshark.jpg', 'dog.jpg', 'wrench.jpg', 'dolphin.jpg', 'lotus.jpg']\n",
"base_url = \"{}?raw=true\"\n",
"for image in images:\n",
"\"{}/{}\".format(image_folder, image)), fname=image,dirname=image_folder)\n",
", fname=utils_file)\n",
", fname=image_net_labels_file)\n",
"from utils import * \n",
"cell_type": "markdown",
"id": "cfaa26c9",
"metadata": {},
"source": [
"## Downloading a model from the ONNX model zoo\n",
"We download a pre-trained model, in our case the [GoogleNet]( model, trained on [ImageNet]( from the [ONNX model zoo]( The model comes packaged in an archive `tar.gz` file containing an `model.onnx` model file."
"cell_type": "markdown",
"id": "a84c6b93",
"metadata": {},
"source": [
"base_url = \"\" \n",
"current_model = \"bvlc_googlenet\"\n",
"model_folder = \"model\"\n",
"archive = \"{}.tar.gz\".format(current_model)\n",
"archive_file = os.path.join(model_folder, archive)\n",
"url = \"{}{}\".format(base_url, archive)\n",
"cell_type": "markdown",
"id": "bbee36b3",
"metadata": {},
"source": [
"Download and extract pre-trained model"
"cell_type": "markdown",
"id": "931ca57f",
"metadata": {},
"source": [
", dirname = model_folder)\n",
"if not os.path.isdir(os.path.join(model_folder, current_model)):\n",
" print('Extracting model...')\n",
" tar =, \"r:gz\")\n",
" tar.extractall(model_folder)\n",
" tar.close()\n",
" print('Extracted')\n",
"cell_type": "markdown",
"id": "b2ed6ac6",
"metadata": {},
"source": [
"The models have been pre-trained on ImageNet, let's load the label mapping of the 1000 classes."
"cell_type": "markdown",
"id": "0f25024d",
"metadata": {},
"source": [
"categories = json.load(open(image_net_labels_file, 'r'))\n",
"cell_type": "markdown",
"id": "54dae113",
"metadata": {},
"source": [
"## Loading the model into MXNet Gluon"
"cell_type": "markdown",
"id": "a49e50de",
"metadata": {},
"source": [
"onnx_path = os.path.join(model_folder, current_model, \"model.onnx\")\n",
"cell_type": "markdown",
"id": "d8ef89ee",
"metadata": {},
"source": [
"We get the symbol and parameter objects"
"cell_type": "markdown",
"id": "32cbad35",
"metadata": {},
"source": [
"sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_path)\n",
"cell_type": "markdown",
"id": "ff1870f4",
"metadata": {},
"source": [
"We pick a context, CPU is fine for inference, switch to mx.gpu() if you want to use your GPU."
"cell_type": "markdown",
"id": "96cb6331",
"metadata": {},
"source": [
"ctx = mx.cpu()\n",
"cell_type": "markdown",
"id": "bce3a1e9",
"metadata": {},
"source": [
"We obtain the data names of the inputs to the model by using the model metadata API:"
"cell_type": "markdown",
"id": "d150b3fb",
"metadata": {},
"source": [
"model_metadata = onnx_mxnet.get_model_metadata(onnx_path)\n",
"cell_type": "markdown",
"id": "a3ee3468",
"metadata": {},
"source": [
"{'output_tensor_data': [(u'gpu_0/softmax_1', (1L, 1000L))],\n",
" 'input_tensor_data': [(u'gpu_0/data_0', (1L, 3L, 224L, 224L))]}\n",
"cell_type": "markdown",
"id": "1c10e556",
"metadata": {},
"source": [
"data_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]\n",
"cell_type": "markdown",
"id": "1135a7c3",
"metadata": {},
"source": [
"And load them into a MXNet Gluon symbol block. \n",
"import warnings\n",
"with warnings.catch_warnings():\n",
" warnings.simplefilter(\"ignore\")\n",
" net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data_0'))\n",
"net_params = net.collect_params()\n",
"for param in arg_params:\n",
" if param in net_params:\n",
" net_params[param]._load_init(arg_params[param], ctx=ctx)\n",
"for param in aux_params:\n",
" if param in net_params:\n",
" net_params[param]._load_init(aux_params[param], ctx=ctx)\n",
"cell_type": "markdown",
"id": "73855938",
"metadata": {},
"source": [
"We can now cache the computational graph through [hybridization]( to gain some performance"
"cell_type": "markdown",
"id": "775d1435",
"metadata": {},
"source": [
"cell_type": "markdown",
"id": "e1bd1dbf",
"metadata": {},
"source": [
"We can visualize the network (requires graphviz installed)"
"cell_type": "markdown",
"id": "b377c080",
"metadata": {},
"source": [
"mx.visualization.plot_network(sym, node_attrs={\"shape\":\"oval\",\"fixedsize\":\"false\"})\n",
"cell_type": "markdown",
"id": "a293ee5e",
"metadata": {},
"source": [
"This is a helper function to run M batches of data of batch-size N through the net and collate the outputs into an array of shape (K, 1000) where K=MxN is the total number of examples (mumber of batches x batch-size) run through the network."
"cell_type": "markdown",
"id": "f71ed3f5",
"metadata": {},
"source": [
"def run_batch(net, data):\n",
" results = []\n",
" for batch in data:\n",
" outputs = net(batch)\n",
" results.extend([o for o in outputs.asnumpy()])\n",
" return np.array(results)\n",
"cell_type": "markdown",
"id": "b4472f0e",
"metadata": {},
"source": [
"## Test using real images"
"cell_type": "markdown",
"id": "f144c595",
"metadata": {},
"source": [
"TOP_P = 3 # How many top guesses we show in the visualization\n",
"cell_type": "markdown",
"id": "d1b6daf4",
"metadata": {},
"source": [
"Transform function to set the data into the format the network expects, (N, 3, 224, 224) where N is the batch size."
"cell_type": "markdown",
"id": "a3b94eed",
"metadata": {},
"source": [
"def transform(img):\n",
" return np.expand_dims(np.transpose(img, (2,0,1)),axis=0).astype(np.float32)\n",
"cell_type": "markdown",
"id": "485ff49e",
"metadata": {},
"source": [
"We load two sets of images in memory"
"cell_type": "markdown",
"id": "d1781e7c",
"metadata": {},
"source": [
"image_net_images = [plt.imread('{}/{}.jpg'.format(image_folder, path)) for path in ['apron', 'hammerheadshark','dog']]\n",
"caltech101_images = [plt.imread('{}/{}.jpg'.format(image_folder, path)) for path in ['wrench', 'dolphin','lotus']]\n",
"images = image_net_images + caltech101_images\n",
"cell_type": "markdown",
"id": "096d6b2d",
"metadata": {},
"source": [
"And run them as a batch through the network to get the predictions"
"cell_type": "markdown",
"id": "1d4dff6a",
"metadata": {},
"source": [
"batch = nd.array(np.concatenate([transform(img) for img in images], axis=0), ctx=ctx)\n",
"result = run_batch(net, [batch])\n",
"cell_type": "markdown",
"id": "abf0eb33",
"metadata": {},
"source": [
"plot_predictions(image_net_images, result[:3], categories, TOP_P)\n",
"cell_type": "markdown",
"id": "5206cc61",
"metadata": {},
"source": [
"**Well done!** Looks like it is doing a pretty good job at classifying pictures when the category is a ImageNet label\n",
"Let's now see the results on the 3 other images"
"cell_type": "markdown",
"id": "ea51bc0a",
"metadata": {},
"source": [
"plot_predictions(caltech101_images, result[3:7], categories, TOP_P)\n",
"cell_type": "markdown",
"id": "7974166d",
"metadata": {},
"source": [
"**Hmm, not so good...** Even though predictions are close, they are not accurate, which is due to the fact that the ImageNet dataset does not contain `wrench`, `dolphin`, or `lotus` categories and our network has been trained on ImageNet.\n",
"Lucky for us, the [Caltech101 dataset]( has them, let's see how we can fine-tune our network to classify these categories correctly.\n",
"We show that in our next tutorial:\n",
"- [Fine-tuning an ONNX Model using the modern imperative MXNet/Gluon](\n",
" \n",
"metadata": {
"language_info": {
"name": "python"
"nbformat": 4,
"nbformat_minor": 5