{"nbformat": 4, "cells": [{"source": "# Predict with pre-trained models\n\nThis tutorial explains how to recognize objects in an image with a\npre-trained model, and how to perform feature extraction.\n\n## Prerequisites\n\nTo complete this tutorial, we need:\n\n- MXNet. See the instructions for your operating system in [Setup and Installation](http://mxnet.io/get_started/install.html)\n\n- [Python Requests](http://docs.python-requests.org/en/master/), [Matplotlib](https://matplotlib.org/) and [Jupyter Notebook](http://jupyter.org/index.html).\n\n```\n$ pip install requests matplotlib jupyter opencv-python\n```\n\n## Loading\n\nWe first download a pre-trained ResNet 152 layer that is trained on the full\nImageNet dataset with over 10 million images and 10 thousand classes. A\npre-trained model contains two parts, a json file containing the model\ndefinition and a binary file containing the parameters. In addition, there may be\na text file for the labels.", "cell_type": "markdown", "metadata": {}}, {"source": "import mxnet as mx\npath='http://data.mxnet.io/models/imagenet-11k/'\n[mx.test_utils.download(path+'resnet-152/resnet-152-symbol.json'),\n mx.test_utils.download(path+'resnet-152/resnet-152-0000.params'),\n mx.test_utils.download(path+'synset.txt')]", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Next, we load the downloaded model. *Note:* If GPU is available, we can replace all\noccurrences of `mx.cpu()` with `mx.gpu()` to accelerate the computation.", "cell_type": "markdown", "metadata": {}}, {"source": "sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)\nmod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)\nmod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))], \n label_shapes=mod._label_shapes)\nmod.set_params(arg_params, aux_params, allow_missing=True)\nwith open('synset.txt', 'r') as f:\n labels = [l.rstrip() for l in f]", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Predicting\n\nWe first define helper functions for downloading an image and performing the\nprediction:", "cell_type": "markdown", "metadata": {}}, {"source": "%matplotlib inline\nimport matplotlib.pyplot as plt\nimport cv2\nimport numpy as np\n# define a simple data batch\nfrom collections import namedtuple\nBatch = namedtuple('Batch', ['data'])\n\ndef get_image(url, show=False):\n # download and show the image\n fname = mx.test_utils.download(url)\n img = cv2.cvtColor(cv2.imread(fname), cv2.COLOR_BGR2RGB)\n if img is None:\n return None\n if show:\n plt.imshow(img)\n plt.axis('off')\n # convert into format (batch, RGB, width, height)\n img = cv2.resize(img, (224, 224))\n img = np.swapaxes(img, 0, 2)\n img = np.swapaxes(img, 1, 2)\n img = img[np.newaxis, :]\n return img\n\ndef predict(url):\n img = get_image(url, show=True)\n # compute the predict probabilities\n mod.forward(Batch([mx.nd.array(img)]))\n prob = mod.get_outputs()[0].asnumpy()\n # print the top-5\n prob = np.squeeze(prob)\n a = np.argsort(prob)[::-1]\n for i in a[0:5]:\n print('probability=%f, class=%s' %(prob[i], labels[i]))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Now, we can perform prediction with any downloadable URL:", "cell_type": "markdown", "metadata": {}}, {"source": "predict('http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg')", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": 
"predict('http://thenotoriouspug.com/wp-content/uploads/2015/01/Pug-Cookie-1920x1080-1024x576.jpg')", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Feature extraction\n\nBy feature extraction, we mean presenting the input images by the output of an\ninternal layer rather than the last softmax layer. These outputs, which can be\nviewed as the feature of the raw input image, can then be used by other\napplications such as object detection.\n\nWe can use the ``get_internals`` method to get all internal layers from a\nSymbol.", "cell_type": "markdown", "metadata": {}}, {"source": "# list the last 10 layers\nall_layers = sym.get_internals()\nall_layers.list_outputs()[-10:]", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "An often used layer for feature extraction is the one before the last fully\nconnected layer. For ResNet, and also Inception, it is the flattened layer with\nname `flatten0` which reshapes the 4-D convolutional layer output into 2-D for\nthe fully connected layer. The following source code extracts a new Symbol which\noutputs the flattened layer and creates a model.", "cell_type": "markdown", "metadata": {}}, {"source": "fe_sym = all_layers['flatten0_output']\nfe_mod = mx.mod.Module(symbol=fe_sym, context=mx.cpu(), label_names=None)\nfe_mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))])\nfe_mod.set_params(arg_params, aux_params)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "We can now invoke `forward` to obtain the features:", "cell_type": "markdown", "metadata": {}}, {"source": "img = get_image('http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg')\nfe_mod.forward(Batch([mx.nd.array(img)]))\nfeatures = fe_mod.get_outputs()[0].asnumpy()\nprint(features)\nassert features.shape == (1, 2048)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "\n<!-- INSERT SOURCE DOWNLOAD BUTTONS -->\n\n", "cell_type": "markdown", "metadata": {}}], "metadata": {"display_name": "", "name": "", "language": "python"}, "nbformat_minor": 2}