| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "id": "8872dc38", |
| "metadata": {}, |
| "source": [ |
| "<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n", |
| "<!--- or more contributor license agreements. See the NOTICE file -->\n", |
| "<!--- distributed with this work for additional information -->\n", |
| "<!--- regarding copyright ownership. The ASF licenses this file -->\n", |
| "<!--- to you under the Apache License, Version 2.0 (the -->\n", |
| "<!--- \"License\"); you may not use this file except in compliance -->\n", |
| "<!--- with the License. You may obtain a copy of the License at -->\n", |
| "\n", |
| "<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n", |
| "\n", |
| "<!--- Unless required by applicable law or agreed to in writing, -->\n", |
| "<!--- software distributed under the License is distributed on an -->\n", |
| "<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n", |
| "<!--- KIND, either express or implied. See the License for the -->\n", |
| "<!--- specific language governing permissions and limitations -->\n", |
| "<!--- under the License. -->\n", |
| "\n", |
| "\n", |
| "# Running inference on MXNet/Gluon from an ONNX model\n", |
| "\n", |
| "[Open Neural Network Exchange (ONNX)](https://github.com/onnx/onnx) provides an open source format for AI models. It defines an extensible computation graph model, as well as definitions of built-in operators and standard data types.\n", |
| "\n", |
| "In this tutorial we will:\n", |
| " \n", |
| "- learn how to load a pre-trained .onnx model file into MXNet/Gluon\n", |
| "- learn how to test this model using the sample input/output\n", |
| "- learn how to test the model on custom images\n", |
| "\n", |
| "## Prerequisites\n", |
| "\n", |
| "To run the tutorial, you will need to have the following Python modules installed:\n", |
| "- [MXNet > 1.1.0](/get_started)\n", |
| "- [onnx](https://github.com/onnx/onnx) (follow the install guide)\n", |
| "- matplotlib" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "af440272", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "import numpy as np\n", |
| "import mxnet as mx\n", |
| "from mxnet.contrib import onnx as onnx_mxnet\n", |
| "from mxnet import gluon, nd\n", |
| "%matplotlib inline\n", |
| "import matplotlib.pyplot as plt\n", |
| "import tarfile, os\n", |
| "import json\n", |
| "import logging\n", |
| "logging.basicConfig(level=logging.INFO)\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "16c099c8", |
| "metadata": {}, |
| "source": [ |
| "### Downloading supporting files\n", |
| "The supporting files are the sample images, the ImageNet label file, and a visualization script." |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "853f948e", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "image_folder = \"images\"\n", |
| "utils_file = \"utils.py\" # contains utility functions for plotting the visualizations\n", |
| "image_net_labels_file = \"image_net_labels.json\"\n", |
| "images = ['apron.jpg', 'hammerheadshark.jpg', 'dog.jpg', 'wrench.jpg', 'dolphin.jpg', 'lotus.jpg']\n", |
| "base_url = \"https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/{}?raw=true\"\n", |
| "\n", |
| "for image in images:\n", |
| "    mx.test_utils.download(base_url.format(\"{}/{}\".format(image_folder, image)), fname=image, dirname=image_folder)\n", |
| "mx.test_utils.download(base_url.format(utils_file), fname=utils_file)\n", |
| "mx.test_utils.download(base_url.format(image_net_labels_file), fname=image_net_labels_file)\n", |
| "\n", |
| "from utils import * \n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "cfaa26c9", |
| "metadata": {}, |
| "source": [ |
| "## Downloading a model from the ONNX model zoo\n", |
| "\n", |
| "We download a pre-trained model, in our case the [GoogleNet](https://arxiv.org/abs/1409.4842) model, trained on [ImageNet](http://www.image-net.org/), from the [ONNX model zoo](https://github.com/onnx/models). The model comes packaged in a `tar.gz` archive containing a `model.onnx` model file." |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "a84c6b93", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "base_url = \"https://s3.amazonaws.com/download.onnx/models/opset_3/\" \n", |
| "current_model = \"bvlc_googlenet\"\n", |
| "model_folder = \"model\"\n", |
| "archive = \"{}.tar.gz\".format(current_model)\n", |
| "archive_file = os.path.join(model_folder, archive)\n", |
| "url = \"{}{}\".format(base_url, archive)\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "bbee36b3", |
| "metadata": {}, |
| "source": [ |
| "Download and extract the pre-trained model:" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "931ca57f", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "mx.test_utils.download(url, dirname = model_folder)\n", |
| "if not os.path.isdir(os.path.join(model_folder, current_model)):\n", |
| "    print('Extracting model...')\n", |
| "    tar = tarfile.open(archive_file, \"r:gz\")\n", |
| "    tar.extractall(model_folder)\n", |
| "    tar.close()\n", |
| "    print('Extracted')\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "b2ed6ac6", |
| "metadata": {}, |
| "source": [ |
| "The model was pre-trained on ImageNet, so let's load the label mapping of its 1000 classes." |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "0f25024d", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "with open(image_net_labels_file, 'r') as f:  # close the file once the labels are loaded\n", |
| "    categories = json.load(f)\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "54dae113", |
| "metadata": {}, |
| "source": [ |
| "## Loading the model into MXNet Gluon" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "a49e50de", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "onnx_path = os.path.join(model_folder, current_model, \"model.onnx\")\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "d8ef89ee", |
| "metadata": {}, |
| "source": [ |
| "We import the model to get the symbol and parameter objects:" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "32cbad35", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_path)\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "ff1870f4", |
| "metadata": {}, |
| "source": [ |
| "We pick a context. CPU is fine for inference; switch to `mx.gpu()` if you want to use your GPU." |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "96cb6331", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "ctx = mx.cpu()\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "bce3a1e9", |
| "metadata": {}, |
| "source": [ |
| "We obtain the data names of the inputs to the model by using the model metadata API:" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "d150b3fb", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "model_metadata = onnx_mxnet.get_model_metadata(onnx_path)\n", |
| "print(model_metadata)\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "a3ee3468", |
| "metadata": {}, |
| "source": [ |
| "```\n", |
| "{'output_tensor_data': [(u'gpu_0/softmax_1', (1L, 1000L))],\n", |
| " 'input_tensor_data': [(u'gpu_0/data_0', (1L, 3L, 224L, 224L))]}\n", |
| "```\n" |
| ] |
| }, |
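| { |
| "cell_type": "markdown", |
| "id": "e3a1f0b2", |
| "metadata": {}, |
| "source": [ |
| "As a small sketch of how this metadata can be read, the snippet below unpacks the names and shapes from a dictionary shaped like the one printed above. The `example_metadata` literal is an assumption copied from that printout (in Python 3 notation); it is not a fresh call to the API.\n", |
| "\n", |
| "```python\n", |
| "# Hypothetical stand-in for the dictionary printed above.\n", |
| "example_metadata = {\n", |
| "    'input_tensor_data': [('gpu_0/data_0', (1, 3, 224, 224))],\n", |
| "    'output_tensor_data': [('gpu_0/softmax_1', (1, 1000))],\n", |
| "}\n", |
| "\n", |
| "# Each entry is a (name, shape) pair.\n", |
| "input_name, input_shape = example_metadata['input_tensor_data'][0]\n", |
| "output_name, output_shape = example_metadata['output_tensor_data'][0]\n", |
| "\n", |
| "print(input_name, input_shape)\n", |
| "print(output_name, output_shape)\n", |
| "```\n", |
| "\n", |
| "The input shape is what the preprocessing has to produce, and the second output dimension is the number of classes." |
| ] |
| }, |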
| { |
| "cell_type": "markdown", |
| "id": "1c10e556", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "data_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]\n", |
| "print(data_names)\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "1135a7c3", |
| "metadata": {}, |
| "source": [ |
| "```[u'data_0']```<!--notebook-skip-line-->\n", |
| "\n", |
| "We then load the symbol and parameters into an MXNet Gluon symbol block.\n", |
| "\n", |
| "```python\n", |
| "import warnings\n", |
| "with warnings.catch_warnings():\n", |
| "    warnings.simplefilter(\"ignore\")\n", |
| "    net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data_0'))\n", |
| "net_params = net.collect_params()\n", |
| "for param in arg_params:\n", |
| "    if param in net_params:\n", |
| "        net_params[param]._load_init(arg_params[param], ctx=ctx)\n", |
| "for param in aux_params:\n", |
| "    if param in net_params:\n", |
| "        net_params[param]._load_init(aux_params[param], ctx=ctx)\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "73855938", |
| "metadata": {}, |
| "source": [ |
| "We can now cache the computational graph through [hybridization](https://mxnet.apache.org/tutorials/gluon/hybrid.html) to gain some performance:" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "775d1435", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "net.hybridize()\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "e1bd1dbf", |
| "metadata": {}, |
| "source": [ |
| "We can visualize the network (this requires graphviz to be installed):" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "b377c080", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "mx.visualization.plot_network(sym, node_attrs={\"shape\":\"oval\",\"fixedsize\":\"false\"})\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "a293ee5e", |
| "metadata": {}, |
| "source": [ |
| "![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/network2.png?raw=true)<!--notebook-skip-line-->\n", |
| "\n", |
| "\n", |
| "\n", |
| "This helper function runs M batches of data, each of batch size N, through the network and collates the outputs into an array of shape (K, 1000), where K = M x N is the total number of examples (number of batches x batch size) run through the network." |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "f71ed3f5", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "def run_batch(net, data):\n", |
| "    results = []\n", |
| "    for batch in data:\n", |
| "        outputs = net(batch)\n", |
| "        results.extend([o for o in outputs.asnumpy()])\n", |
| "    return np.array(results)\n", |
| "```\n" |
| ] |
| }, |
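| { |
| "cell_type": "markdown", |
| "id": "7c4d92aa", |
| "metadata": {}, |
| "source": [ |
| "To make the shape bookkeeping concrete, here is a minimal sketch that feeds the same helper a stand-in network. `FakeOutput` and `fake_net` are hypothetical placeholders that only mimic the `asnumpy()` call; they are not part of MXNet.\n", |
| "\n", |
| "```python\n", |
| "import numpy as np\n", |
| "\n", |
| "def run_batch(net, data):\n", |
| "    # Same helper as above, repeated so the sketch is self-contained.\n", |
| "    results = []\n", |
| "    for batch in data:\n", |
| "        outputs = net(batch)\n", |
| "        results.extend([o for o in outputs.asnumpy()])\n", |
| "    return np.array(results)\n", |
| "\n", |
| "class FakeOutput:\n", |
| "    # Hypothetical wrapper exposing asnumpy() like an MXNet NDArray.\n", |
| "    def __init__(self, array):\n", |
| "        self.array = array\n", |
| "\n", |
| "    def asnumpy(self):\n", |
| "        return self.array\n", |
| "\n", |
| "def fake_net(batch):\n", |
| "    # Hypothetical network: one row of 1000 zero scores per example.\n", |
| "    return FakeOutput(np.zeros((batch.shape[0], 1000), dtype=np.float32))\n", |
| "\n", |
| "M, N = 2, 3  # two batches of three examples each\n", |
| "data = [np.zeros((N, 3, 224, 224), dtype=np.float32) for _ in range(M)]\n", |
| "collated = run_batch(fake_net, data)\n", |
| "print(collated.shape)  # (M * N, 1000)\n", |
| "```\n" |
| ] |
| }, |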
| { |
| "cell_type": "markdown", |
| "id": "b4472f0e", |
| "metadata": {}, |
| "source": [ |
| "## Test using real images" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "f144c595", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "TOP_P = 3 # How many top guesses we show in the visualization\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "d1b6daf4", |
| "metadata": {}, |
| "source": [ |
| "The transform function puts the data into the format the network expects: (N, 3, 224, 224), where N is the batch size." |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "a3b94eed", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "def transform(img):\n", |
| "    return np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32)\n", |
| "```\n" |
| ] |
| }, |
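| { |
| "cell_type": "markdown", |
| "id": "b81e56c3", |
| "metadata": {}, |
| "source": [ |
| "As a quick illustration on synthetic data: an RGB image read with `plt.imread` is a (height, width, channel) array, so the transpose moves the channel axis to the front and `expand_dims` adds the batch axis. The random `dummy_img` below is an assumption standing in for a real 224x224 image.\n", |
| "\n", |
| "```python\n", |
| "import numpy as np\n", |
| "\n", |
| "def transform(img):\n", |
| "    # Same function as above, repeated so the sketch is self-contained.\n", |
| "    return np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32)\n", |
| "\n", |
| "# Hypothetical 224x224 RGB image with uint8 pixels.\n", |
| "dummy_img = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)\n", |
| "x = transform(dummy_img)\n", |
| "print(x.shape, x.dtype)\n", |
| "\n", |
| "# The pixel at row 10, column 20, channel 2 ends up at x[0, 2, 10, 20].\n", |
| "assert x[0, 2, 10, 20] == dummy_img[10, 20, 2]\n", |
| "```\n", |
| "\n", |
| "Note that the function does not resize anything, so the input images must already be 224x224." |
| ] |
| }, |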
| { |
| "cell_type": "markdown", |
| "id": "485ff49e", |
| "metadata": {}, |
| "source": [ |
| "We load two sets of images into memory:" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "d1781e7c", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "image_net_images = [plt.imread('{}/{}.jpg'.format(image_folder, path)) for path in ['apron', 'hammerheadshark','dog']]\n", |
| "caltech101_images = [plt.imread('{}/{}.jpg'.format(image_folder, path)) for path in ['wrench', 'dolphin','lotus']]\n", |
| "images = image_net_images + caltech101_images\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "096d6b2d", |
| "metadata": {}, |
| "source": [ |
| "And run them as a batch through the network to get the predictions" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "1d4dff6a", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "batch = nd.array(np.concatenate([transform(img) for img in images], axis=0), ctx=ctx)\n", |
| "result = run_batch(net, [batch])\n", |
| "```\n" |
| ] |
| }, |
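| { |
| "cell_type": "markdown", |
| "id": "2f9a0d47", |
| "metadata": {}, |
| "source": [ |
| "Each row of `result` holds the 1000 class scores for one image. As a minimal sketch of reading the top guesses as text, the snippet below ranks scores with NumPy. The `scores` array and `labels` list are synthetic stand-ins for `result` and `categories` (assuming labels can be indexed by class id), and `top_k` is a hypothetical helper, not part of the downloaded utilities.\n", |
| "\n", |
| "```python\n", |
| "import numpy as np\n", |
| "\n", |
| "def top_k(scores, labels, k=3):\n", |
| "    # Indices of the k highest scores in each row, best first.\n", |
| "    idx = np.argsort(scores, axis=1)[:, ::-1][:, :k]\n", |
| "    return [[(labels[i], float(row[i])) for i in row_idx]\n", |
| "            for row, row_idx in zip(scores, idx)]\n", |
| "\n", |
| "# Synthetic stand-ins: 2 examples, 4 classes.\n", |
| "labels = ['apron', 'dog', 'shark', 'wrench']\n", |
| "scores = np.array([[0.1, 0.6, 0.2, 0.1],\n", |
| "                   [0.7, 0.1, 0.15, 0.05]])\n", |
| "guesses = top_k(scores, labels, k=2)\n", |
| "for row in guesses:\n", |
| "    print(row)\n", |
| "```\n" |
| ] |
| }, |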
| { |
| "cell_type": "markdown", |
| "id": "abf0eb33", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "plot_predictions(image_net_images, result[:3], categories, TOP_P)\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "5206cc61", |
| "metadata": {}, |
| "source": [ |
| "![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/imagenet.png?raw=true)<!--notebook-skip-line-->\n", |
| "\n", |
| "\n", |
| "**Well done!** Looks like it is doing a pretty good job at classifying pictures when the category is an ImageNet label.\n", |
| "\n", |
| "Let's now see the results on the three other images:" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "ea51bc0a", |
| "metadata": {}, |
| "source": [ |
| "```python\n", |
| "plot_predictions(caltech101_images, result[3:6], categories, TOP_P)\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "7974166d", |
| "metadata": {}, |
| "source": [ |
| "![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/caltech101.png?raw=true)<!--notebook-skip-line-->\n", |
| "\n", |
| "\n", |
| "**Hmm, not so good...** The predictions are close but not accurate, because the ImageNet dataset does not contain `wrench`, `dolphin`, or `lotus` categories and our network has been trained on ImageNet.\n", |
| "\n", |
| "Luckily, the [Caltech101 dataset](http://www.vision.caltech.edu/Image_Datasets/Caltech101/) has them. Let's see how we can fine-tune our network to classify these categories correctly.\n", |
| "\n", |
| "We show that in our next tutorial:\n", |
| "\n", |
| "\n", |
| "- [Fine-tuning an ONNX Model using the modern imperative MXNet/Gluon](http://mxnet.apache.org/tutorials/onnx/fine_tuning_gluon.html)\n", |
| " \n", |
| "<!-- INSERT SOURCE DOWNLOAD BUTTONS -->" |
| ] |
| } |
| ], |
| "metadata": { |
| "language_info": { |
| "name": "python" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |