{
"cells": [
{
"cell_type": "markdown",
"id": "c7e7e094",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"# Importing an ONNX model into MXNet\n",
"\n",
"In this tutorial we will:\n",
"\n",
"- learn how to load a pre-trained ONNX model file into MXNet.\n",
"- run inference in MXNet.\n",
"\n",
"## Prerequisites\n",
"This example assumes that the following Python packages are installed:\n",
"- [mxnet](/get_started)\n",
"- [onnx](https://github.com/onnx/onnx) (follow the install guide)\n",
"- Pillow: a Python image processing package, required for input pre-processing. It can be installed with ```pip install Pillow```.\n",
"- matplotlib"
]
},
{
"cell_type": "markdown",
"id": "d7cc71fc",
"metadata": {},
"source": [
"```python\n",
"from PIL import Image\n",
"import numpy as np\n",
"import mxnet as mx\n",
"import mxnet.contrib.onnx as onnx_mxnet\n",
"from mxnet.test_utils import download\n",
"from matplotlib.pyplot import imshow\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "decf8d87",
"metadata": {},
"source": [
"### Fetching the required files"
]
},
{
"cell_type": "markdown",
"id": "d48b9dbc",
"metadata": {},
"source": [
"```python\n",
"img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg'\n",
"download(img_url, 'super_res_input.jpg')\n",
"model_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_resolution.onnx'\n",
"onnx_model_file = download(model_url, 'super_resolution.onnx')\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "5ee632aa",
"metadata": {},
"source": [
"## Loading the model into MXNet\n",
"\n",
"To completely describe a pre-trained model in MXNet, we need two elements: a symbolic graph containing the model's network definition, and a binary file containing the model weights. You can import the ONNX model and get the symbol and parameter objects using the ``import_model`` API. The parameter object is split into argument parameters and auxiliary parameters."
]
},
{
"cell_type": "markdown",
"id": "89e9523d",
"metadata": {},
"source": [
"```python\n",
"sym, arg, aux = onnx_mxnet.import_model(onnx_model_file)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "04e270e3",
"metadata": {},
"source": [
"We can now visualize the imported model (this requires graphviz to be installed):"
]
},
{
"cell_type": "markdown",
"id": "7e4fba75",
"metadata": {},
"source": [
"```python\n",
"mx.viz.plot_network(sym, node_attrs={\"shape\":\"oval\",\"fixedsize\":\"false\"})\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "d6118196",
"metadata": {},
"source": [
"![svg](https://s3.amazonaws.com/onnx-mxnet/examples/super_res_mxnet_model.png) <!--notebook-skip-line-->\n",
"\n",
"\n",
"\n",
"## Input Pre-processing\n",
"\n",
"We will transform the previously downloaded input image into an input tensor."
]
},
{
"cell_type": "markdown",
"id": "49dd891b",
"metadata": {},
"source": [
"```python\n",
"img = Image.open('super_res_input.jpg').resize((224, 224))\n",
"img_ycbcr = img.convert(\"YCbCr\")\n",
"img_y, img_cb, img_cr = img_ycbcr.split()\n",
"test_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]\n",
"```\n"
]
},
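{
"cell_type": "markdown",
"id": "3f9a1c2e",
"metadata": {},
"source": [
"The indexing with `np.newaxis` in the cell above prepends a batch axis and a channel axis to the luminance channel. The following is a minimal sketch of that step on a stand-in array, assuming a 224x224 single-channel image; `dummy_y` is a hypothetical placeholder and is not part of the tutorial's data.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical stand-in for the Y (luminance) channel of a 224x224 image.\n",
"dummy_y = np.zeros((224, 224), dtype=np.uint8)\n",
"\n",
"# Prepend two axes of length 1: batch (N) and channel (C).\n",
"batched = dummy_y[np.newaxis, np.newaxis, :, :]\n",
"print(batched.shape)  # (1, 1, 224, 224)\n",
"```\n",
"\n",
"The same shape is what gets passed to `mod.bind` later as the data shape."
]
},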
{
"cell_type": "markdown",
"id": "5fd422cd",
"metadata": {},
"source": [
"## Run Inference using MXNet's Module API\n",
"\n",
"We will use MXNet's Module API to run the inference. For this we need to create the module, bind it to the input data, and assign the loaded weights from the two parameter objects: argument parameters and auxiliary parameters.\n",
"\n",
"To obtain the input data names we run the following line, which picks all the inputs of the symbol graph excluding the argument and auxiliary parameters:"
]
},
{
"cell_type": "markdown",
"id": "608c51cd",
"metadata": {},
"source": [
"```python\n",
"data_names = [graph_input for graph_input in sym.list_inputs()\n",
" if graph_input not in arg and graph_input not in aux]\n",
"print(data_names)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "832f7c1a",
"metadata": {},
"source": [
"```['1']```\n",
"\n",
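"The filter above can be illustrated without MXNet. This is a minimal sketch in plain Python; the input names and the two dictionaries are hypothetical stand-ins for `sym.list_inputs()`, `arg` and `aux`, and only the name `'1'` comes from the output shown above.\n",
"\n",
"```python\n",
"# Hypothetical stand-ins for sym.list_inputs() and the imported parameter dicts.\n",
"graph_inputs = ['1', 'conv1_weight', 'conv1_bias', 'bn1_moving_mean']\n",
"arg_params = {'conv1_weight': None, 'conv1_bias': None}\n",
"aux_params = {'bn1_moving_mean': None}\n",
"\n",
"# Keep only the inputs that are not backed by a loaded parameter.\n",
"data_inputs = [name for name in graph_inputs\n",
"               if name not in arg_params and name not in aux_params]\n",
"print(data_inputs)  # ['1']\n",
"```\n",
"\n",
"With the data names known, we can create the module, bind it, and load the weights:\n",
"\n",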
"```python\n",
"mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)\n",
"mod.bind(for_training=False, data_shapes=[(data_names[0],test_image.shape)], label_shapes=None)\n",
"mod.set_params(arg_params=arg, aux_params=aux, allow_missing=True, allow_extra=True)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "5d4a97e0",
"metadata": {},
"source": [
"The Module API's forward method requires a batch of data as input. We will prepare the data in that format and feed it to the forward method."
]
},
{
"cell_type": "markdown",
"id": "06cee5e2",
"metadata": {},
"source": [
"```python\n",
"from collections import namedtuple\n",
"Batch = namedtuple('Batch', ['data'])\n",
"\n",
"# forward on the provided data batch\n",
"mod.forward(Batch([mx.nd.array(test_image)]))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "99ddca6d",
"metadata": {},
"source": [
"To get the output of the previous forward computation, use the ``mod.get_outputs()`` method.\n",
"It returns a list of ``NDArray`` objects; we take the first one, convert it to a ``numpy`` array, and then to Pillow's image format."
]
},
{
"cell_type": "markdown",
"id": "7767d796",
"metadata": {},
"source": [
"```python\n",
"output = mod.get_outputs()[0][0][0]\n",
"# Clip to the valid pixel range and cast to uint8; mode='L' belongs to Image.fromarray\n",
"img_out_y = Image.fromarray(np.uint8(output.asnumpy().clip(0, 255)), mode='L')\n",
"result_img = Image.merge(\n",
"    \"YCbCr\", [\n",
"        img_out_y,\n",
"        img_cb.resize(img_out_y.size, Image.BICUBIC),\n",
"        img_cr.resize(img_out_y.size, Image.BICUBIC)\n",
"    ]).convert(\"RGB\")\n",
"result_img.save(\"super_res_output.jpg\")\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "b662664d",
"metadata": {},
"source": [
"You can now compare the input image and the resulting output image. The model increased the spatial resolution from ``224x224`` (the size the input was resized to) to ``672x672``, a factor of 3 in each dimension.\n",
"\n",
"| Input Image | Output Image | <!--notebook-skip-line-->\n",
"| ----------- | ------------ | <!--notebook-skip-line-->\n",
"| ![input](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_input.jpg?raw=true) | ![output](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_output.jpg?raw=true) | <!--notebook-skip-line-->\n",
"\n",
"<!-- INSERT SOURCE DOWNLOAD BUTTONS -->"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}