blob: 714bdf46d2fb89a0149be94433a954b2bb255f90 [file] [log] [blame]
{"nbformat": 4, "cells": [{"source": "# Module - Neural network training and inference\n\nTraining a neural network involves quite a few steps. One need to specify how\nto feed input training data, initialize model parameters, perform forward and\nbackward passes through the network, update weights based on computed gradients, do\nmodel checkpoints, etc. During prediction, one ends up repeating most of these\nsteps. All this can be quite daunting to both newcomers as well as experienced\ndevelopers.\n\nLuckily, MXNet modularizes commonly used code for training and inference in\nthe `module` (`mod` for short) package. `Module` provides both high-level and\nintermediate-level interfaces for executing predefined networks. One can use\nboth interfaces interchangeably. We will show the usage of both interfaces in\nthis tutorial.\n\n## Prerequisites\n\nTo complete this tutorial, we need:\n\n- MXNet. See the instructions for your operating system in [Setup and Installation](http://mxnet.io/install/index.html). \n\n- [Jupyter Notebook](http://jupyter.org/index.html) and [Python Requests](http://docs.python-requests.org/en/master/) packages.\n```\npip install jupyter requests\n```\n\n## Preliminary\n\nIn this tutorial we will demonstrate `module` usage by training a\n[Multilayer Perceptron](https://en.wikipedia.org/wiki/Multilayer_perceptron) (MLP)\non the [UCI letter recognition](https://archive.ics.uci.edu/ml/datasets/letter+recognition)\ndataset.\n\nThe following code downloads the dataset and creates an 80:20 train:test\nsplit. It also initializes a training data iterator to return a batch of 32\ntraining examples each time. A separate iterator is also created for test data.", "cell_type": "markdown", "metadata": {}}, {"source": "import logging\nlogging.getLogger().setLevel(logging.INFO)\nimport mxnet as mx\nimport numpy as np\n\nmx.random.seed(1234)\nfname = mx.test_utils.download('http://archive.ics.uci.edu/ml/machine-learning-databases/letter-recognition/letter-recognition.data')\ndata = np.genfromtxt(fname, delimiter=',')[:,1:]\nlabel = np.array([ord(l.split(',')[0])-ord('A') for l in open(fname, 'r')])\n\nbatch_size = 32\nntrain = int(data.shape[0]*0.8)\ntrain_iter = mx.io.NDArrayIter(data[:ntrain, :], label[:ntrain], batch_size, shuffle=True)\nval_iter = mx.io.NDArrayIter(data[ntrain:, :], label[ntrain:], batch_size)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Next, we define the network.", "cell_type": "markdown", "metadata": {}}, {"source": "net = mx.sym.Variable('data')\nnet = mx.sym.FullyConnected(net, name='fc1', num_hidden=64)\nnet = mx.sym.Activation(net, name='relu1', act_type=\"relu\")\nnet = mx.sym.FullyConnected(net, name='fc2', num_hidden=26)\nnet = mx.sym.SoftmaxOutput(net, name='softmax')\nmx.viz.plot_network(net)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "![svg](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/basic/module/output_3_0.svg?sanitize=true)\n\n\n\n## Creating a Module\n\nNow we are ready to introduce module. The commonly used module class is\n`Module`. We can construct a module by specifying the following parameters:\n\n- `symbol`: the network definition\n- `context`: the device (or a list of devices) to use for execution\n- `data_names` : the list of input data variable names\n- `label_names` : the list of input label variable names\n\nFor `net`, we have only one data named `data`, and one label named `softmax_label`,\nwhich is automatically named for us following the name `softmax` we specified for the `SoftmaxOutput` operator.", "cell_type": "markdown", "metadata": {}}, {"source": "mod = mx.mod.Module(symbol=net,\n context=mx.cpu(),\n data_names=['data'],\n label_names=['softmax_label'])", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Intermediate-level Interface\n\nWe have created module. Now let us see how to run training and inference using module's intermediate-level APIs. These APIs give developers flexibility to do step-by-step\ncomputation by running `forward` and `backward` passes. It's also useful for debugging.\n\nTo train a module, we need to perform following steps:\n\n- `bind` : Prepares environment for the computation by allocating memory.\n- `init_params` : Assigns and initializes parameters.\n- `init_optimizer` : Initializes optimizers. Defaults to `sgd`.\n- `metric.create` : Creates evaluation metric from input metric name.\n- `forward` : Forward computation.\n- `update_metric` : Evaluates and accumulates evaluation metric on outputs of the last forward computation.\n- `backward` : Backward computation.\n- `update` : Updates parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch.\n\nThis can be used as follows:", "cell_type": "markdown", "metadata": {}}, {"source": "# allocate memory given the input data and label shapes\nmod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)\n# initialize parameters by uniform random numbers\nmod.init_params(initializer=mx.init.Uniform(scale=.1))\n# use SGD with learning rate 0.1 to train\nmod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))\n# use accuracy as the metric\nmetric = mx.metric.create('acc')\n# train 5 epochs, i.e. going over the data iter one pass\nfor epoch in range(5):\n train_iter.reset()\n metric.reset()\n for batch in train_iter:\n mod.forward(batch, is_train=True) # compute predictions\n mod.update_metric(metric, batch.label) # accumulate prediction accuracy\n mod.backward() # compute gradients\n mod.update() # update parameters\n print('Epoch %d, Training %s' % (epoch, metric.get()))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": " Epoch 0, Training ('accuracy', 0.4554375)\n Epoch 1, Training ('accuracy', 0.6485625)\n Epoch 2, Training ('accuracy', 0.7055625)\n Epoch 3, Training ('accuracy', 0.7396875)\n Epoch 4, Training ('accuracy', 0.764375)\n\n\nTo learn more about these APIs, visit [Module API](http://mxnet.io/api/python/module/module.html).\n\n## High-level Interface\n\n### Train\n\nModule also provides high-level APIs for training, predicting and evaluating for\nuser convenience. Instead of doing all the steps mentioned in the above section,\none can simply call [fit API](http://mxnet.io/api/python/module/module.html#mxnet.module.BaseModule.fit)\nand it internally executes the same steps.\n\nTo fit a module, call the `fit` function as follows:", "cell_type": "markdown", "metadata": {}}, {"source": "# reset train_iter to the beginning\ntrain_iter.reset()\n\n# create a module\nmod = mx.mod.Module(symbol=net,\n context=mx.cpu(),\n data_names=['data'],\n label_names=['softmax_label'])\n\n# fit the module\nmod.fit(train_iter,\n eval_data=val_iter,\n optimizer='sgd',\n optimizer_params={'learning_rate':0.1},\n eval_metric='acc',\n num_epoch=8)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": " INFO:root:Epoch[0] Train-accuracy=0.364625\n INFO:root:Epoch[0] Time cost=0.388\n INFO:root:Epoch[0] Validation-accuracy=0.557250\n INFO:root:Epoch[1] Train-accuracy=0.633625\n INFO:root:Epoch[1] Time cost=0.470\n INFO:root:Epoch[1] Validation-accuracy=0.634750\n INFO:root:Epoch[2] Train-accuracy=0.697187\n INFO:root:Epoch[2] Time cost=0.402\n INFO:root:Epoch[2] Validation-accuracy=0.665500\n INFO:root:Epoch[3] Train-accuracy=0.735062\n INFO:root:Epoch[3] Time cost=0.402\n INFO:root:Epoch[3] Validation-accuracy=0.713000\n INFO:root:Epoch[4] Train-accuracy=0.762563\n INFO:root:Epoch[4] Time cost=0.408\n INFO:root:Epoch[4] Validation-accuracy=0.742000\n INFO:root:Epoch[5] Train-accuracy=0.782312\n INFO:root:Epoch[5] Time cost=0.400\n INFO:root:Epoch[5] Validation-accuracy=0.778500\n INFO:root:Epoch[6] Train-accuracy=0.797188\n INFO:root:Epoch[6] Time cost=0.392\n INFO:root:Epoch[6] Validation-accuracy=0.798250\n INFO:root:Epoch[7] Train-accuracy=0.807750\n INFO:root:Epoch[7] Time cost=0.401\n INFO:root:Epoch[7] Validation-accuracy=0.789250\n\n\nBy default, `fit` function has `eval_metric` set to `accuracy`, `optimizer` to `sgd`\nand optimizer_params to `(('learning_rate', 0.01),)`.\n\n### Predict and Evaluate\n\nTo predict with module, we can call `predict()`. It will collect and\nreturn all the prediction results.", "cell_type": "markdown", "metadata": {}}, {"source": "y = mod.predict(val_iter)\nassert y.shape == (4000, 26)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "If we do not need the prediction outputs, but just need to evaluate on a test\nset, we can call the `score()` function. It runs prediction in the input validation\ndataset and evaluates the performance according to the given input metric.\n\nIt can be used as follows:", "cell_type": "markdown", "metadata": {}}, {"source": "score = mod.score(val_iter, ['acc'])\nprint(\"Accuracy score is %f\" % (score[0][1]))\nassert score[0][1] > 0.77, \"Achieved accuracy (%f) is less than expected (0.77)\" % score[0][1]", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": " Accuracy score is 0.789250\n\n\nSome of the other metrics which can be used are `top_k_acc`(top-k-accuracy),\n`F1`, `RMSE`, `MSE`, `MAE`, `ce`(CrossEntropy). To learn more about the metrics,\nvisit [Evaluation metric](http://mxnet.io/api/python/metric/metric.html).\n\nOne can vary number of epochs, learning_rate, optimizer parameters to change the score\nand tune these parameters to get best score.\n\n### Save and Load\n\nWe can save the module parameters after each training epoch by using a checkpoint callback.", "cell_type": "markdown", "metadata": {}}, {"source": "# construct a callback function to save checkpoints\nmodel_prefix = 'mx_mlp'\ncheckpoint = mx.callback.do_checkpoint(model_prefix)\n\nmod = mx.mod.Module(symbol=net)\nmod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": " INFO:root:Epoch[0] Train-accuracy=0.101062\n INFO:root:Epoch[0] Time cost=0.422\n INFO:root:Saved checkpoint to \"mx_mlp-0001.params\"\n INFO:root:Epoch[1] Train-accuracy=0.263313\n INFO:root:Epoch[1] Time cost=0.785\n INFO:root:Saved checkpoint to \"mx_mlp-0002.params\"\n INFO:root:Epoch[2] Train-accuracy=0.452188\n INFO:root:Epoch[2] Time cost=0.624\n INFO:root:Saved checkpoint to \"mx_mlp-0003.params\"\n INFO:root:Epoch[3] Train-accuracy=0.544125\n INFO:root:Epoch[3] Time cost=0.427\n INFO:root:Saved checkpoint to \"mx_mlp-0004.params\"\n INFO:root:Epoch[4] Train-accuracy=0.605250\n INFO:root:Epoch[4] Time cost=0.399\n INFO:root:Saved checkpoint to \"mx_mlp-0005.params\"\n\n\nTo load the saved module parameters, call the `load_checkpoint` function. It\nloads the Symbol and the associated parameters. We can then set the loaded\nparameters into the module.", "cell_type": "markdown", "metadata": {}}, {"source": "sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)\nassert sym.tojson() == net.tojson()\n\n# assign the loaded parameters to the module\nmod.set_params(arg_params, aux_params)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Or if we just want to resume training from a saved checkpoint, instead of\ncalling `set_params()`, we can directly call `fit()`, passing the loaded\nparameters, so that `fit()` knows to start from those parameters instead of\ninitializing randomly from scratch. We also set the `begin_epoch` parameter so that\n`fit()` knows we are resuming from a previously saved epoch.", "cell_type": "markdown", "metadata": {}}, {"source": "mod = mx.mod.Module(symbol=sym)\nmod.fit(train_iter,\n num_epoch=21,\n arg_params=arg_params,\n aux_params=aux_params,\n begin_epoch=3)\nassert score[0][1] > 0.77, \"Achieved accuracy (%f) is less than expected (0.77)\" % score[0][1] ", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "\n INFO:root:Epoch[3] Train-accuracy=0.544125\n INFO:root:Epoch[3] Time cost=0.398\n INFO:root:Epoch[4] Train-accuracy=0.605250\n INFO:root:Epoch[4] Time cost=0.545\n INFO:root:Epoch[5] Train-accuracy=0.644312\n INFO:root:Epoch[5] Time cost=0.592\n INFO:root:Epoch[6] Train-accuracy=0.675000\n INFO:root:Epoch[6] Time cost=0.491\n INFO:root:Epoch[7] Train-accuracy=0.695812\n INFO:root:Epoch[7] Time cost=0.363\n\n\n\n<!-- INSERT SOURCE DOWNLOAD BUTTONS -->\n\n\n\n", "cell_type": "markdown", "metadata": {}}], "metadata": {"display_name": "", "name": "", "language": "python"}, "nbformat_minor": 2}