| |
| |
| |
| |
| <!doctype html> |
| <html lang="en" class="no-js"> |
| <head> |
| |
| <meta charset="utf-8"> |
| <meta name="viewport" content="width=device-width,initial-scale=1"> |
| <meta http-equiv="x-ua-compatible" content="ie=edge"> |
| |
| |
| |
| |
| <meta name="lang:clipboard.copy" content="Copy to clipboard"> |
| |
| <meta name="lang:clipboard.copied" content="Copied to clipboard"> |
| |
| <meta name="lang:search.language" content="en"> |
| |
| <meta name="lang:search.pipeline.stopwords" content="True"> |
| |
| <meta name="lang:search.pipeline.trimmer" content="True"> |
| |
| <meta name="lang:search.result.none" content="No matching documents"> |
| |
| <meta name="lang:search.result.one" content="1 matching document"> |
| |
| <meta name="lang:search.result.other" content="# matching documents"> |
| |
| <meta name="lang:search.tokenizer" content="[\s\-]+"> |
| |
| <link rel="shortcut icon" href="../../assets/images/favicon.png"> |
| <meta name="generator" content="mkdocs-1.0.4, mkdocs-material-4.6.0"> |
| |
| |
| |
| <title>Digit Recognition on MNIST - MXNet.jl</title> |
| |
| |
| |
| <link rel="stylesheet" href="../../assets/stylesheets/application.1b62728e.css"> |
| |
| |
| |
| |
| <script src="../../assets/javascripts/modernizr.268332fc.js"></script> |
| |
| |
| |
| <link href="https://fonts.gstatic.com" rel="preconnect" crossorigin> |
| <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,400,400i,700%7CRoboto+Mono&display=fallback"> |
| <style>body,input{font-family:"Roboto","Helvetica Neue",Helvetica,Arial,sans-serif}code,kbd,pre{font-family:"Roboto Mono","Courier New",Courier,monospace}</style> |
| |
| |
| <link rel="stylesheet" href="../../assets/fonts/material-icons.css"> |
| |
| |
| <link rel="stylesheet" href="../../assets/Documenter.css"> |
| |
| |
| |
| |
| |
| </head> |
| |
| <body dir="ltr"> |
| |
| <svg class="md-svg"> |
| <defs> |
| |
| |
| <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> |
| |
| </defs> |
| </svg> |
| <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off"> |
| <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off"> |
| <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> |
| |
| <a href="#digit-recognition-on-mnist" tabindex="1" class="md-skip"> |
| Skip to content |
| </a> |
| |
| |
| <header class="md-header" data-md-component="header"> |
| <nav class="md-header-nav md-grid"> |
| <div class="md-flex"> |
| <div class="md-flex__cell md-flex__cell--shrink"> |
| <a href="../.." title="MXNet.jl" class="md-header-nav__button md-logo"> |
| |
| <i class="md-icon"></i> |
| |
| </a> |
| </div> |
| <div class="md-flex__cell md-flex__cell--shrink"> |
| <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> |
| </div> |
| <div class="md-flex__cell md-flex__cell--stretch"> |
| <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> |
| |
| <span class="md-header-nav__topic"> |
| MXNet.jl |
| </span> |
| <span class="md-header-nav__topic"> |
| |
| Digit Recognition on MNIST |
| |
| </span> |
| |
| </div> |
| </div> |
| <div class="md-flex__cell md-flex__cell--shrink"> |
| |
| <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> |
| |
| <div class="md-search" data-md-component="search" role="dialog"> |
| <label class="md-search__overlay" for="__search"></label> |
| <div class="md-search__inner" role="search"> |
| <form class="md-search__form" name="search"> |
| <input type="text" class="md-search__input" name="query" placeholder="Search" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="query" data-md-state="active"> |
| <label class="md-icon md-search__icon" for="__search"></label> |
| <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> |
|  |
| </button> |
| </form> |
| <div class="md-search__output"> |
| <div class="md-search__scrollwrap" data-md-scrollfix> |
| <div class="md-search-result" data-md-component="result"> |
| <div class="md-search-result__meta"> |
| Type to start searching |
| </div> |
| <ol class="md-search-result__list"></ol> |
| </div> |
| </div> |
| </div> |
| </div> |
| </div> |
| |
| </div> |
| |
| <div class="md-flex__cell md-flex__cell--shrink"> |
| <div class="md-header-nav__source"> |
| |
| |
| |
| |
| |
| <a href="https://github.com/apache/incubator-mxnet/tree/master/julia#mxnet/" title="Go to repository" class="md-source" data-md-source="github"> |
| |
| <div class="md-source__icon"> |
| <svg viewBox="0 0 24 24" width="24" height="24"> |
| <use xlink:href="#__github" width="24" height="24"></use> |
| </svg> |
| </div> |
| |
| <div class="md-source__repository"> |
| GitHub |
| </div> |
| </a> |
| </div> |
| </div> |
| |
| </div> |
| </nav> |
| </header> |
| |
| <div class="md-container"> |
| |
| |
| |
| |
| <main class="md-main" role="main"> |
| <div class="md-main__inner md-grid" data-md-component="container"> |
| |
| |
| <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> |
| <div class="md-sidebar__scrollwrap"> |
| <div class="md-sidebar__inner"> |
| <nav class="md-nav md-nav--primary" data-md-level="0"> |
| <label class="md-nav__title md-nav__title--site" for="__drawer"> |
| <a href="../.." title="MXNet.jl" class="md-nav__button md-logo"> |
| |
| <i class="md-icon"></i> |
| |
| </a> |
| MXNet.jl |
| </label> |
| |
| <div class="md-nav__source"> |
| |
| |
| |
| |
| |
| <a href="https://github.com/apache/incubator-mxnet/tree/master/julia#mxnet/" title="Go to repository" class="md-source" data-md-source="github"> |
| |
| <div class="md-source__icon"> |
| <svg viewBox="0 0 24 24" width="24" height="24"> |
| <use xlink:href="#__github" width="24" height="24"></use> |
| </svg> |
| </div> |
| |
| <div class="md-source__repository"> |
| GitHub |
| </div> |
| </a> |
| </div> |
| |
| <ul class="md-nav__list" data-md-scrollfix> |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../.." title="Home" class="md-nav__link"> |
| Home |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item md-nav__item--active md-nav__item--nested"> |
| |
| <input class="md-toggle md-nav__toggle" data-md-toggle="nav-2" type="checkbox" id="nav-2" checked> |
| |
| <label class="md-nav__link" for="nav-2"> |
| Tutorial |
| </label> |
| <nav class="md-nav" data-md-component="collapsible" data-md-level="1"> |
| <label class="md-nav__title" for="nav-2"> |
| Tutorial |
| </label> |
| <ul class="md-nav__list" data-md-scrollfix> |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item md-nav__item--active"> |
| |
| <input class="md-toggle md-nav__toggle" data-md-toggle="toc" type="checkbox" id="__toc"> |
| |
| |
| |
| |
| <label class="md-nav__link md-nav__link--active" for="__toc"> |
| Digit Recognition on MNIST |
| </label> |
| |
| <a href="./" title="Digit Recognition on MNIST" class="md-nav__link md-nav__link--active"> |
| Digit Recognition on MNIST |
| </a> |
| |
| |
| <nav class="md-nav md-nav--secondary"> |
| |
| |
| |
| |
| |
| <label class="md-nav__title" for="__toc">Table of contents</label> |
| <ul class="md-nav__list" data-md-scrollfix> |
| |
| <li class="md-nav__item"> |
| <a href="#simple-3-layer-mlp" class="md-nav__link"> |
| Simple 3-layer MLP |
| </a> |
| |
| </li> |
| |
| <li class="md-nav__item"> |
| <a href="#convolutional-neural-networks" class="md-nav__link"> |
| Convolutional Neural Networks |
| </a> |
| |
| </li> |
| |
| <li class="md-nav__item"> |
| <a href="#predicting-with-a-trained-model" class="md-nav__link"> |
| Predicting with a trained model |
| </a> |
| |
| </li> |
| |
| |
| |
| |
| |
| </ul> |
| |
| </nav> |
| |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../char-lstm/" title="Generating Random Sentence with LSTM RNN" class="md-nav__link"> |
| Generating Random Sentence with LSTM RNN |
| </a> |
| </li> |
| |
| |
| </ul> |
| </nav> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item md-nav__item--nested"> |
| |
| <input class="md-toggle md-nav__toggle" data-md-toggle="nav-3" type="checkbox" id="nav-3"> |
| |
| <label class="md-nav__link" for="nav-3"> |
| User Guide |
| </label> |
| <nav class="md-nav" data-md-component="collapsible" data-md-level="1"> |
| <label class="md-nav__title" for="nav-3"> |
| User Guide |
| </label> |
| <ul class="md-nav__list" data-md-scrollfix> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../user-guide/install/" title="Installation Guide" class="md-nav__link"> |
| Installation Guide |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../user-guide/overview/" title="Overview" class="md-nav__link"> |
| Overview |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../user-guide/faq/" title="FAQ" class="md-nav__link"> |
| FAQ |
| </a> |
| </li> |
| |
| |
| </ul> |
| </nav> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item md-nav__item--nested"> |
| |
| <input class="md-toggle md-nav__toggle" data-md-toggle="nav-4" type="checkbox" id="nav-4"> |
| |
| <label class="md-nav__link" for="nav-4"> |
| API Documentation |
| </label> |
| <nav class="md-nav" data-md-component="collapsible" data-md-level="1"> |
| <label class="md-nav__title" for="nav-4"> |
| API Documentation |
| </label> |
| <ul class="md-nav__list" data-md-scrollfix> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../api/context/" title="Context" class="md-nav__link"> |
| Context |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../api/model/" title="Models" class="md-nav__link"> |
| Models |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../api/initializer/" title="Initializers" class="md-nav__link"> |
| Initializers |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../api/optimizer/" title="Optimizers" class="md-nav__link"> |
| Optimizers |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../api/callback/" title="Callbacks in training" class="md-nav__link"> |
| Callbacks in training |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../api/metric/" title="Evaluation Metrics" class="md-nav__link"> |
| Evaluation Metrics |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../api/io/" title="Data Providers" class="md-nav__link"> |
| Data Providers |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../api/ndarray/" title="NDArray API" class="md-nav__link"> |
| NDArray API |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../api/symbolic-node/" title="Symbolic API" class="md-nav__link"> |
| Symbolic API |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../api/nn-factory/" title="Neural Networks Factory" class="md-nav__link"> |
| Neural Networks Factory |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../api/executor/" title="Executor" class="md-nav__link"> |
| Executor |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../api/kvstore/" title="Key-Value Store" class="md-nav__link"> |
| Key-Value Store |
| </a> |
| </li> |
| |
| |
| |
| |
| |
| |
| |
| <li class="md-nav__item"> |
| <a href="../../api/visualize/" title="Network Visualization" class="md-nav__link"> |
| Network Visualization |
| </a> |
| </li> |
| |
| |
| </ul> |
| </nav> |
| </li> |
| |
| |
| </ul> |
| </nav> |
| </div> |
| </div> |
| </div> |
| |
| |
| <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> |
| <div class="md-sidebar__scrollwrap"> |
| <div class="md-sidebar__inner"> |
| |
| <nav class="md-nav md-nav--secondary"> |
| |
| |
| |
| |
| |
| <label class="md-nav__title" for="__toc">Table of contents</label> |
| <ul class="md-nav__list" data-md-scrollfix> |
| |
| <li class="md-nav__item"> |
| <a href="#simple-3-layer-mlp" class="md-nav__link"> |
| Simple 3-layer MLP |
| </a> |
| |
| </li> |
| |
| <li class="md-nav__item"> |
| <a href="#convolutional-neural-networks" class="md-nav__link"> |
| Convolutional Neural Networks |
| </a> |
| |
| </li> |
| |
| <li class="md-nav__item"> |
| <a href="#predicting-with-a-trained-model" class="md-nav__link"> |
| Predicting with a trained model |
| </a> |
| |
| </li> |
| |
| |
| |
| |
| |
| </ul> |
| |
| </nav> |
| </div> |
| </div> |
| </div> |
| |
| |
| <div class="md-content"> |
| <article class="md-content__inner md-typeset"> |
| |
| |
| <a href="https://github.com/apache/incubator-mxnet/tree/master/edit/master/docs/tutorial/mnist.md" title="Edit this page" class="md-icon md-content__icon"></a> |
| |
| |
| <!-- Licensed to the Apache Software Foundation (ASF) under one --> <!-- or more contributor license agreements. See the NOTICE file --> <!-- distributed with this work for additional information --> <!-- regarding copyright ownership. The ASF licenses this file --> <!-- to you under the Apache License, Version 2.0 (the --> <!-- "License"); you may not use this file except in compliance --> <!-- with the License. You may obtain a copy of the License at --> |
| |
| <!-- http://www.apache.org/licenses/LICENSE-2.0 --> |
| |
| <!-- Unless required by applicable law or agreed to in writing, --> <!-- software distributed under the License is distributed on an --> <!-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --> <!-- KIND, either express or implied. See the License for the --> <!-- specific language governing permissions and limitations --> <!-- under the License. --> |
| |
| <p><a id='Digit-Recognition-on-MNIST-1'></a></p> |
| <h1 id="digit-recognition-on-mnist">Digit Recognition on MNIST</h1> |
| <p>In this tutorial, we will work through examples of training a simple multi-layer perceptron and then a convolutional neural network (the LeNet architecture) on the <a href="http://yann.lecun.com/exdb/mnist/">MNIST handwritten digit dataset</a>. The code for this tutorial can be found in <a href="https://github.com/apache/incubator-mxnet/tree/master/julia/examples/mnist">examples/mnist</a>. There are also two Jupyter notebooks that expand a little more on the <a href="https://github.com/ultradian/julia_notebooks/blob/master/mxnet/mnistMLP.ipynb">MLP</a> and the <a href="https://github.com/ultradian/julia_notebooks/blob/master/mxnet/mnistLenet.ipynb">LeNet</a>, using the more general <code>ArrayDataProvider</code>.</p> |
| <p><a id='Simple-3-layer-MLP-1'></a></p> |
| <h2 id="simple-3-layer-mlp">Simple 3-layer MLP</h2> |
| <p>This is a tiny 3-layer MLP that can be easily trained on a CPU. The script starts with</p> |
| <pre><code class="julia">using MXNet |
| </code></pre> |
| |
| <p>to load the <code>MXNet</code> module. Then we are ready to define the network architecture via the <a href="/api/julia/docs/api/user-guide/overview/">symbolic API</a>. We start with a placeholder <code>data</code> symbol,</p> |
| <pre><code class="julia">data = mx.Variable(:data) |
| </code></pre> |
| |
| <p>and then cascading fully-connected layers and activation functions:</p> |
| <pre><code class="julia">fc1 = mx.FullyConnected(data, name=:fc1, num_hidden=128) |
| act1 = mx.Activation(fc1, name=:relu1, act_type=:relu) |
| fc2 = mx.FullyConnected(act1, name=:fc2, num_hidden=64) |
| act2 = mx.Activation(fc2, name=:relu2, act_type=:relu) |
| fc3 = mx.FullyConnected(act2, name=:fc3, num_hidden=10) |
| </code></pre> |
| |
| <p>Note that in each composition we take the previous symbol as the first argument, forming a feedforward chain. The architecture looks like</p> |
| <pre><code>Input --> 128 units (ReLU) --> 64 units (ReLU) --> 10 units |
| </code></pre> |
| |
| <p>where the last 10 units correspond to the 10 output classes (digits 0,...,9). We then add a final <code>SoftmaxOutput</code> operation to turn the 10-dimensional prediction into proper probability values for the 10 classes:</p> |
| <pre><code class="julia">mlp = mx.SoftmaxOutput(fc3, name=:softmax) |
| </code></pre> |
| |
| <p>As we can see, the MLP is just a chain of layers. For this case, we can also use the <code>@mx.chain</code> macro. The same architecture above can be defined as</p> |
| <pre><code class="julia">mlp = @mx.chain mx.Variable(:data) => |
| mx.FullyConnected(name=:fc1, num_hidden=128) => |
| mx.Activation(name=:relu1, act_type=:relu) => |
| mx.FullyConnected(name=:fc2, num_hidden=64) => |
| mx.Activation(name=:relu2, act_type=:relu) => |
| mx.FullyConnected(name=:fc3, num_hidden=10) => |
| mx.SoftmaxOutput(name=:softmax) |
| </code></pre> |
| |
| <p>After defining the architecture, we are ready to load the MNIST data. MXNet.jl provides built-in data providers for the MNIST dataset, which can automatically download the dataset into <code>Pkg.dir("MXNet")/data/mnist</code> if necessary. We wrap the code to construct the data provider into <code>mnist-data.jl</code> so that it can be shared by both the MLP example and the LeNet ConvNets example.</p> |
| <pre><code class="julia">batch_size = 100 |
| include("mnist-data.jl") |
| train_provider, eval_provider = get_mnist_providers(batch_size) |
| </code></pre> |
| |
| <p>If you need to write your own data providers for customized data format, please refer to <a href="../../api/io/#MXNet.mx.AbstractDataProvider"><code>mx.AbstractDataProvider</code></a>.</p> |
| <p>Given the architecture and data, we can instantiate a <em>model</em> to do the actual training. <code>mx.FeedForward</code> is the built-in model that is suitable for most feed-forward architectures. When constructing the model, we also specify the <em>context</em> on which the computation should be carried out. Because this is a really tiny MLP, we will just run on a single CPU device.</p> |
| <pre><code class="julia">model = mx.FeedForward(mlp, context=mx.cpu()) |
| </code></pre> |
| |
| <p>You can also use <code>mx.gpu()</code> as the context; if a list of devices (e.g. <code>[mx.gpu(0), mx.gpu(1)]</code>) is provided, data parallelization will be used automatically. But for this tiny example, using a GPU device might not help.</p> |
| <p>The last thing we need to specify is the optimization algorithm (a.k.a. <em>optimizer</em>) to use. We use the basic SGD with a fixed learning rate 0.1, momentum 0.9 and weight decay 0.00001:</p> |
| <pre><code class="julia">optimizer = mx.SGD(η=0.1, μ=0.9, λ=0.00001) |
| </code></pre> |
| |
| <p>Now we can do the training. Here the <code>n_epoch</code> parameter specifies that we want to train for 20 epochs. We also supply <code>eval_data</code> to monitor accuracy on the validation set.</p> |
| <pre><code class="julia">mx.fit(model, optimizer, train_provider, n_epoch=20, eval_data=eval_provider) |
| </code></pre> |
| |
| <p>Here is a sample output</p> |
| <pre><code>INFO: Start training on [CPU0] |
| INFO: Initializing parameters... |
| INFO: Creating KVStore... |
| INFO: == Epoch 001 ========== |
| INFO: ## Training summary |
| INFO: :accuracy = 0.7554 |
| INFO: time = 1.3165 seconds |
| INFO: ## Validation summary |
| INFO: :accuracy = 0.9502 |
| ... |
| INFO: == Epoch 020 ========== |
| INFO: ## Training summary |
| INFO: :accuracy = 0.9949 |
| INFO: time = 0.9287 seconds |
| INFO: ## Validation summary |
| INFO: :accuracy = 0.9775 |
| </code></pre> |
| |
| <p><a id='Convolutional-Neural-Networks-1'></a></p> |
| <h2 id="convolutional-neural-networks">Convolutional Neural Networks</h2> |
| <p>In the second example, we show a slightly more complicated architecture that involves convolution and pooling. This architecture for MNIST is usually called LeNet. The first part of the architecture is listed below:</p> |
| <pre><code class="julia"># input |
| data = mx.Variable(:data) |
| |
| # first conv |
| conv1 = @mx.chain mx.Convolution(data, kernel=(5,5), num_filter=20) => |
| mx.Activation(act_type=:tanh) => |
| mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2)) |
| |
| # second conv |
| conv2 = @mx.chain mx.Convolution(conv1, kernel=(5,5), num_filter=50) => |
| mx.Activation(act_type=:tanh) => |
| mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2)) |
| </code></pre> |
| |
| <p>We basically defined two convolution modules. Each convolution module is actually a chain of <code>Convolution</code>, <code>tanh</code> activation and then max <code>Pooling</code> operations.</p> |
| <p>Each sample in the MNIST dataset is a 28x28 single-channel grayscale image. In the tensor format used by <code>NDArray</code>, a batch of 100 samples is a tensor of shape <code>(28,28,1,100)</code>. The convolution and pooling operate along the spatial axes, so <code>kernel=(5,5)</code> indicates a square region 5 wide and 5 high. The rest of the architecture follows as:</p> |
| <pre><code class="julia"># first fully-connected |
| fc1 = @mx.chain mx.Flatten(conv2) => |
| mx.FullyConnected(num_hidden=500) => |
| mx.Activation(act_type=:tanh) |
| |
| # second fully-connected |
| fc2 = mx.FullyConnected(fc1, num_hidden=10) |
| |
| # softmax loss |
| lenet = mx.Softmax(fc2, name=:softmax) |
| </code></pre> |
| |
| <p>Note that a fully-connected operator expects the input to be a matrix. However, the results from spatial convolution and pooling are 4D tensors. So we explicitly used a <code>Flatten</code> operator to flatten the tensor before connecting it to the <code>FullyConnected</code> operator.</p> |
| <p>The rest of the network is the same as the previous MLP example. As before, we can now load the MNIST dataset:</p> |
| <pre><code class="julia">batch_size = 100 |
| include("mnist-data.jl") |
| train_provider, eval_provider = get_mnist_providers(batch_size; flat=false) |
| </code></pre> |
| |
| <p>Note that we specified <code>flat=false</code> to tell the data provider to provide 4D tensors instead of 2D matrices, because the convolution operators need correct spatial shape information. We then construct a feedforward model on GPU, and train it.</p> |
| <pre><code class="julia"># fit model |
| model = mx.FeedForward(lenet, context=mx.gpu()) |
| |
| # optimizer |
| optimizer = mx.SGD(η=0.05, μ=0.9, λ=0.00001) |
| |
| # fit parameters |
| mx.fit(model, optimizer, train_provider, n_epoch=20, eval_data=eval_provider) |
| </code></pre> |
| |
| <p>And here is a sample of running outputs:</p> |
| <pre><code>INFO: == Epoch 001 ========== |
| INFO: ## Training summary |
| INFO: :accuracy = 0.6750 |
| INFO: time = 4.9814 seconds |
| INFO: ## Validation summary |
| INFO: :accuracy = 0.9712 |
| ... |
| INFO: == Epoch 020 ========== |
| INFO: ## Training summary |
| INFO: :accuracy = 1.0000 |
| INFO: time = 4.0086 seconds |
| INFO: ## Validation summary |
| INFO: :accuracy = 0.9915 |
| </code></pre> |
| |
| <p><a id='Predicting-with-a-trained-model-1'></a></p> |
| <h2 id="predicting-with-a-trained-model">Predicting with a trained model</h2> |
| <p>Predicting with a trained model is very simple. By calling <code>mx.predict</code> with the model and a data provider, we get the model output as a Julia Array:</p> |
| <pre><code class="julia">probs = mx.predict(model, eval_provider) |
| </code></pre> |
| |
| <p>The following code shows a naive way of getting all the labels from the data provider and computing the prediction accuracy manually:</p> |
| <pre><code class="julia"># collect all labels from eval data |
| labels = reduce( |
| vcat, |
| copy(mx.get(eval_provider, batch, :softmax_label)) for batch ∈ eval_provider) |
| # labels are 0...9 |
| labels .= labels .+ 1 |
| |
| # Now we compute the accuracy |
| pred = map(i -> argmax(probs[1:10, i]), 1:size(probs, 2)) |
| correct = sum(pred .== labels) |
| @printf "Accuracy on eval set: %.2f%%\n" 100correct/length(labels) |
| </code></pre> |
| |
| <p>Alternatively, when the dataset is huge, one can provide a callback to <code>mx.predict</code>; the callback function will then be invoked with the outputs of each mini-batch. The callback could, for example, write the data to disk for future inspection. In this case, no value is returned from <code>mx.predict</code>. See also <code>mx.predict</code> in the <a href="../../api/model/">Models API documentation</a>.</p> |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| </article> |
| </div> |
| </div> |
| </main> |
| |
| |
| <footer class="md-footer"> |
| |
| <div class="md-footer-nav"> |
| <nav class="md-footer-nav__inner md-grid"> |
| |
| <a href="../.." title="Home" class="md-flex md-footer-nav__link md-footer-nav__link--prev" rel="prev"> |
| <div class="md-flex__cell md-flex__cell--shrink"> |
| <i class="md-icon md-icon--arrow-back md-footer-nav__button"></i> |
| </div> |
| <div class="md-flex__cell md-flex__cell--stretch md-footer-nav__title"> |
| <span class="md-flex__ellipsis"> |
| <span class="md-footer-nav__direction"> |
| Previous |
| </span> |
| Home |
| </span> |
| </div> |
| </a> |
| |
| |
| <a href="../char-lstm/" title="Generating Random Sentence with LSTM RNN" class="md-flex md-footer-nav__link md-footer-nav__link--next" rel="next"> |
| <div class="md-flex__cell md-flex__cell--stretch md-footer-nav__title"> |
| <span class="md-flex__ellipsis"> |
| <span class="md-footer-nav__direction"> |
| Next |
| </span> |
| Generating Random Sentence with LSTM RNN |
| </span> |
| </div> |
| <div class="md-flex__cell md-flex__cell--shrink"> |
| <i class="md-icon md-icon--arrow-forward md-footer-nav__button"></i> |
| </div> |
| </a> |
| |
| </nav> |
| </div> |
| |
| <div class="md-footer-meta md-typeset"> |
| <div class="md-footer-meta__inner md-grid"> |
| <div class="md-footer-copyright"> |
| |
| powered by |
| <a href="https://www.mkdocs.org">MkDocs</a> |
| and |
| <a href="https://squidfunk.github.io/mkdocs-material/"> |
| Material for MkDocs</a> |
| </div> |
| |
| </div> |
| </div> |
| </footer> |
| |
| </div> |
| |
| <script src="../../assets/javascripts/application.808e90bb.js"></script> |
| |
| <script>app.initialize({version:"1.0.4",url:{base:"../.."}})</script> |
| |
| <script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML"></script> |
| |
| <script src="../../assets/mathjaxhelper.js"></script> |
| |
| |
| </body> |
| </html> |