| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "id": "954625fb", |
| "metadata": {}, |
| "source": [ |
| "<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n", |
| "<!--- or more contributor license agreements. See the NOTICE file -->\n", |
| "<!--- distributed with this work for additional information -->\n", |
| "<!--- regarding copyright ownership. The ASF licenses this file -->\n", |
| "<!--- to you under the Apache License, Version 2.0 (the -->\n", |
| "<!--- \"License\"); you may not use this file except in compliance -->\n", |
| "<!--- with the License. You may obtain a copy of the License at -->\n", |
| "\n", |
| "<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n", |
| "\n", |
| "<!--- Unless required by applicable law or agreed to in writing, -->\n", |
| "<!--- software distributed under the License is distributed on an -->\n", |
| "<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n", |
| "<!--- KIND, either express or implied. See the License for the -->\n", |
| "<!--- specific language governing permissions and limitations -->\n", |
| "<!--- under the License. -->\n", |
| "\n", |
| "\n", |
| "# Running inference on MXNet/Gluon from an ONNX model\n", |
| "\n", |
| "[Open Neural Network Exchange (ONNX)](https://github.com/onnx/onnx) provides an open source format for AI models. It defines an extensible computation graph model, as well as definitions of built-in operators and standard data types.\n", |
| "\n", |
| "In this tutorial we will:\n", |
| "\n", |
| "- learn how to load a pre-trained .onnx model file into MXNet/Gluon\n", |
| "- learn how to test this model using the sample input/output\n", |
| "- learn how to test the model on custom images\n", |
| "\n", |
| "## Pre-requisite\n", |
| "\n", |
| "To run the tutorial you will need to have installed the following python modules:\n", |
| "- [MXNet > 1.1.0](https://mxnet.apache.org/get_started)\n", |
| "- [onnx](https://github.com/onnx/onnx) (follow the install guide)\n", |
| "- matplotlib" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "4a4dc4d4", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import numpy as np\n", |
| "import mxnet as mx\n", |
| "from mxnet.contrib import onnx as onnx_mxnet\n", |
| "from mxnet import gluon, nd\n", |
| "%matplotlib inline\n", |
| "import matplotlib.pyplot as plt\n", |
| "import tarfile, os\n", |
| "import json\n", |
| "import logging\n", |
| "logging.basicConfig(level=logging.INFO)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "c8938f9e", |
| "metadata": {}, |
| "source": [ |
| "### Downloading supporting files\n", |
| "These are images and a vizualisation script" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "e749c9c9", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "image_folder = \"images\"\n", |
| "utils_file = \"utils.py\" # contain utils function to plot nice visualization\n", |
| "image_net_labels_file = \"image_net_labels.json\"\n", |
| "images = ['apron.jpg', 'hammerheadshark.jpg', 'dog.jpg', 'wrench.jpg', 'dolphin.jpg', 'lotus.jpg']\n", |
| "base_url = \"https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/{}?raw=true\"\n", |
| "\n", |
| "for image in images:\n", |
| " mx.test_utils.download(base_url.format(\"{}/{}\".format(image_folder, image)), fname=image,dirname=image_folder)\n", |
| "mx.test_utils.download(base_url.format(utils_file), fname=utils_file)\n", |
| "mx.test_utils.download(base_url.format(image_net_labels_file), fname=image_net_labels_file)\n", |
| "\n", |
| "from utils import *" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "4518dbdf", |
| "metadata": {}, |
| "source": [ |
| "## Downloading a model from the ONNX model zoo\n", |
| "\n", |
| "We download a pre-trained model, in our case the [GoogleNet](https://arxiv.org/abs/1409.4842) model, trained on [ImageNet](http://www.image-net.org/) from the [ONNX model zoo](https://github.com/onnx/models). The model comes packaged in an archive `tar.gz` file containing an `model.onnx` model file." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "37b2df22", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "base_url = \"https://s3.amazonaws.com/download.onnx/models/opset_3/\"\n", |
| "current_model = \"bvlc_googlenet\"\n", |
| "model_folder = \"model\"\n", |
| "archive = \"{}.tar.gz\".format(current_model)\n", |
| "archive_file = os.path.join(model_folder, archive)\n", |
| "url = \"{}{}\".format(base_url, archive)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "23e29ce7", |
| "metadata": {}, |
| "source": [ |
| "Download and extract pre-trained model" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "10cf2e1e", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "mx.test_utils.download(url, dirname = model_folder)\n", |
| "if not os.path.isdir(os.path.join(model_folder, current_model)):\n", |
| " print('Extracting model...')\n", |
| " tar = tarfile.open(archive_file, \"r:gz\")\n", |
| " tar.extractall(model_folder)\n", |
| " tar.close()\n", |
| " print('Extracted')" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "ac55756f", |
| "metadata": {}, |
| "source": [ |
| "The models have been pre-trained on ImageNet, let's load the label mapping of the 1000 classes." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "e93ed870", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "categories = json.load(open(image_net_labels_file, 'r'))" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "aa43b716", |
| "metadata": {}, |
| "source": [ |
| "## Loading the model into MXNet Gluon" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "9e4c0cae", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "onnx_path = os.path.join(model_folder, current_model, \"model.onnx\")" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "4fb3ef07", |
| "metadata": {}, |
| "source": [ |
| "We get the symbol and parameter objects" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "c9c13e57", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_path)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "e6908285", |
| "metadata": {}, |
| "source": [ |
| "We pick a device, CPU is fine for inference, switch to mx.gpu() if you want to use your GPU." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "c08c7b5d", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "device = mx.cpu()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "ae61bb3a", |
| "metadata": {}, |
| "source": [ |
| "We obtain the data names of the inputs to the model by using the model metadata API:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "635c8818", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "model_metadata = onnx_mxnet.get_model_metadata(onnx_path)\n", |
| "print(model_metadata)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "af7810b8", |
| "metadata": {}, |
| "source": [ |
| "```\n", |
| "{'output_tensor_data': [(u'gpu_0/softmax_1', (1L, 1000L))],\n", |
| " 'input_tensor_data': [(u'gpu_0/data_0', (1L, 3L, 224L, 224L))]}\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "9cfdac34", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "data_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]\n", |
| "print(data_names)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "a8a3a8aa", |
| "metadata": {}, |
| "source": [ |
| "And load them into a MXNet Gluon symbol block." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "80cd0957", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import warnings\n", |
| "with warnings.catch_warnings():\n", |
| " warnings.simplefilter(\"ignore\")\n", |
| " net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data_0'))\n", |
| "net_params = net.collect_params()\n", |
| "for param in arg_params:\n", |
| " if param in net_params:\n", |
| " net_params[param]._load_init(arg_params[param], device=device)\n", |
| "for param in aux_params:\n", |
| " if param in net_params:\n", |
| " net_params[param]._load_init(aux_params[param], device=device)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "069d723d", |
| "metadata": {}, |
| "source": [ |
| "We can now cache the computational graph through [hybridization](https://mxnet.apache.org/versions/master/api/python/docs/tutorials/packages/gluon/blocks/hybridize.html) to gain some performance" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "314dafe3", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "net.hybridize()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "91aeaba8", |
| "metadata": {}, |
| "source": [ |
| "We can visualize the network (requires graphviz installed)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "65f60def", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "mx.visualization.plot_network(sym, node_attrs={\"shape\":\"oval\",\"fixedsize\":\"false\"})" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "ace59190", |
| "metadata": {}, |
| "source": [ |
| "![network2](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/network2.png?raw=true)<!--notebook-skip-line-->\n", |
| "\n", |
| "\n", |
| "\n", |
| "This is a helper function to run M batches of data of batch-size N through the net and collate the outputs into an array of shape (K, 1000) where K=MxN is the total number of examples (mumber of batches x batch-size) run through the network." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "d424af63", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def run_batch(net, data):\n", |
| " results = []\n", |
| " for batch in data:\n", |
| " outputs = net(batch)\n", |
| " results.extend([o for o in outputs.asnumpy()])\n", |
| " return np.array(results)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "3d51763b", |
| "metadata": {}, |
| "source": [ |
| "## Test using real images" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "27e59f47", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "TOP_P = 3 # How many top guesses we show in the visualization" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "4519d05d", |
| "metadata": {}, |
| "source": [ |
| "Transform function to set the data into the format the network expects, (N, 3, 224, 224) where N is the batch size." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "acd10b01", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def transform(img):\n", |
| " return np.expand_dims(np.transpose(img, (2,0,1)),axis=0).astype(np.float32)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "59f29daa", |
| "metadata": {}, |
| "source": [ |
| "We load two sets of images in memory" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "90f0040f", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "image_net_images = [plt.imread('{}/{}.jpg'.format(image_folder, path)) for path in ['apron', 'hammerheadshark','dog']]\n", |
| "caltech101_images = [plt.imread('{}/{}.jpg'.format(image_folder, path)) for path in ['wrench', 'dolphin','lotus']]\n", |
| "images = image_net_images + caltech101_images" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "b7306dfb", |
| "metadata": {}, |
| "source": [ |
| "And run them as a batch through the network to get the predictions" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "8eb1aa79", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "batch = nd.array(np.concatenate([transform(img) for img in images], axis=0), device=device)\n", |
| "result = run_batch(net, [batch])" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "4fee3e6b", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "plot_predictions(image_net_images, result[:3], categories, TOP_P)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "4643d505", |
| "metadata": {}, |
| "source": [ |
| "![imagenet](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/imagenet.png?raw=true)<!--notebook-skip-line-->\n", |
| "\n", |
| "\n", |
| "**Well done!** Looks like it is doing a pretty good job at classifying pictures when the category is a ImageNet label\n", |
| "\n", |
| "Let's now see the results on the 3 other images" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "8f6d1ee8", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "plot_predictions(caltech101_images, result[3:7], categories, TOP_P)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "110f4ef4", |
| "metadata": {}, |
| "source": [ |
| "![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/caltech101.png?raw=true)<!--notebook-skip-line-->\n", |
| "\n", |
| "\n", |
| "**Hmm, not so good...** Even though predictions are close, they are not accurate, which is due to the fact that the ImageNet dataset does not contain `wrench`, `dolphin`, or `lotus` categories and our network has been trained on ImageNet.\n", |
| "\n", |
| "Lucky for us, the [Caltech101 dataset](https://data.caltech.edu/records/20086) has them, let's see how we can fine-tune our network to classify these categories correctly.\n", |
| "\n", |
| "We show that in our next tutorial:\n", |
| "\n", |
| "\n", |
| "- [Fine-tuning an ONNX Model using the modern imperative MXNet/Gluon](https://mxnet.apache.org/versions/master/api/python/docs/tutorials/packages/onnx/fine_tuning_gluon.html)\n", |
| "\n", |
| "<!-- INSERT SOURCE DOWNLOAD BUTTONS -->" |
| ] |
| } |
| ], |
| "metadata": { |
| "language_info": { |
| "name": "python" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |