| <!DOCTYPE html> |
| |
| <!--- |
| Licensed to the Apache Software Foundation (ASF) under one |
| or more contributor license agreements. See the NOTICE file |
| distributed with this work for additional information |
| regarding copyright ownership. The ASF licenses this file |
| to you under the Apache License, Version 2.0 (the |
| "License"); you may not use this file except in compliance |
| with the License. You may obtain a copy of the License at |
| http://www.apache.org/licenses/LICENSE-2.0 |
| Unless required by applicable law or agreed to in writing, |
| software distributed under the License is distributed on an |
| "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| KIND, either express or implied. See the License for the |
| specific language governing permissions and limitations |
| under the License. |
| --> |
| |
| <html lang="en"><head> |
| <meta charset="utf-8"> |
| <meta http-equiv="X-UA-Compatible" content="IE=edge"> |
| <meta name="viewport" content="width=device-width, initial-scale=1"> |
| <link href="/versions/1.9.1/assets/img/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 --> |
| <title>MNIST Competition | Apache MXNet</title> |
| <meta name="generator" content="Jekyll v3.8.6" /> |
| <meta property="og:title" content="MNIST Competition" /> |
| <meta property="og:locale" content="en_US" /> |
| <meta name="description" content="A flexible and efficient library for deep learning." /> |
| <meta property="og:description" content="A flexible and efficient library for deep learning." /> |
| <link rel="canonical" href="https://mxnet.apache.org/versions/1.9.1/api/r/docs/tutorials/mnist_competition" /> |
| <meta property="og:url" content="https://mxnet.apache.org/versions/1.9.1/api/r/docs/tutorials/mnist_competition" /> |
| <meta property="og:site_name" content="Apache MXNet" /> |
| <script type="application/ld+json"> |
| {"headline":"MNIST Competition","description":"A flexible and efficient library for deep learning.","url":"https://mxnet.apache.org/versions/1.9.1/api/r/docs/tutorials/mnist_competition","@type":"WebPage","@context":"https://schema.org"}</script> |
| <!-- End Jekyll SEO tag --> |
| <link rel="stylesheet" href="/versions/1.9.1/assets/docsearch.min.css" /><link rel="stylesheet" href="/versions/1.9.1/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/1.9.1/feed.xml" title="Apache MXNet" /><!-- Matomo --> |
| <script> |
| var _paq = window._paq = window._paq || []; |
| /* tracker methods like "setCustomDimension" should be called before "trackPageView" */ |
| /* We explicitly disable cookie tracking to avoid privacy issues */ |
| _paq.push(['disableCookies']); |
| _paq.push(['trackPageView']); |
| _paq.push(['enableLinkTracking']); |
| (function() { |
| var u="https://analytics.apache.org/"; |
| _paq.push(['setTrackerUrl', u+'matomo.php']); |
| _paq.push(['setSiteId', '23']); |
| var d=document, g=d.createElement('script'), s=d.getElementsByTagName('script')[0]; |
| g.async=true; g.src=u+'matomo.js'; s.parentNode.insertBefore(g,s); |
| })(); |
| </script> |
| <!-- End Matomo Code --> |
| |
| <script src="/versions/1.9.1/assets/js/jquery-3.3.1.min.js"></script> |
| <script src="/versions/1.9.1/assets/js/docsearch.min.js"></script><script src="/versions/1.9.1/assets/js/globalSearch.js" defer></script> |
| <script src="/versions/1.9.1/assets/js/clipboard.js" defer></script> |
| <script src="/versions/1.9.1/assets/js/copycode.js" defer></script></head> |
| <body><header class="site-header" role="banner"> |
| |
| <script> |
| $(document).ready(function () { |
| |
| // HEADER OPACITY LOGIC |
| |
| function opacity_header() { |
| var value = "rgba(4,140,204," + ($(window).scrollTop() / 300 + 0.4) + ")" |
| $('.site-header').css("background-color", value) |
| } |
| |
| $(window).scroll(function () { |
| opacity_header() |
| }) |
| opacity_header(); |
| |
| // MENU SELECTOR LOGIC |
| $('.page-link').each( function () { |
| if (window.location.href.includes(this.href)) { |
| $(this).addClass("page-current"); |
| } |
| }); |
| }) |
| </script> |
| <div class="wrapper"> |
| <a class="site-title" rel="author" href="/versions/1.9.1/"><img |
| src="/versions/1.9.1/assets/img/mxnet_logo.png" class="site-header-logo"></a> |
| <nav class="site-nav"> |
| <input type="checkbox" id="nav-trigger" class="nav-trigger"/> |
| <label for="nav-trigger"> |
| <span class="menu-icon"> |
| <svg viewBox="0 0 18 15" width="18px" height="15px"> |
| <path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/> |
| </svg> |
| </span> |
| </label> |
| <div class="gs-search-border"> |
| <div id="gs-search-icon"></div> |
| <form id="global-search-form"> |
| <input id="global-search" type="text" title="Search" placeholder="Search" /> |
| <div id="global-search-dropdown-container"> |
| <button class="gs-current-version btn" type="button" data-toggle="dropdown"> |
| <span id="gs-current-version-label">1.9.1</span> |
| <svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"> |
| <path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path> |
| </svg> |
| </button> |
| <ul class="gs-opt-group gs-version-dropdown"> |
| |
| |
| <li class="gs-opt gs-versions">master</li> |
| |
| |
| |
| <li class="gs-opt gs-versions active">1.9.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.8.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.7.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.6.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.5.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.4.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.3.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.2.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.1.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.0.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.12.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.11.0</li> |
| |
| |
| </ul> |
| </div> |
| <span id="global-search-close">x</span> |
| </form> |
| </div> |
| <div class="trigger"> |
| <div id="global-search-mobile-border"> |
| <div id="gs-search-icon-mobile"></div> |
| <input id="global-search-mobile" placeholder="Search..." type="text"/> |
| <div id="global-search-dropdown-container-mobile"> |
| <button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown"> |
| <svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"> |
| <path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path> |
| </svg> |
| </button> |
| <ul class="gs-opt-group gs-version-dropdown-mobile"> |
| |
| |
| <li class="gs-opt gs-versions">master</li> |
| |
| |
| |
| <li class="gs-opt gs-versions active">1.9.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.8.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.7.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.6.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.5.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.4.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.3.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.2.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.1.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.0.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.12.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.11.0</li> |
| |
| |
| </ul> |
| </div> |
| </div> |
| <a class="page-link" href="/versions/1.9.1/get_started">Get Started</a> |
| <a class="page-link" href="/versions/1.9.1/features">Features</a> |
| <a class="page-link" href="/versions/1.9.1/ecosystem">Ecosystem</a> |
| <a class="page-link" href="/versions/1.9.1/api">Docs & Tutorials</a> |
| <a class="page-link" href="/versions/1.9.1/trusted_by">Trusted By</a> |
| <a class="page-link" href="https://github.com/apache/incubator-mxnet">GitHub</a> |
| <div class="dropdown" style="min-width:100px"> |
| <span class="dropdown-header">Apache |
| <svg class="dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg> |
| </span> |
| <div class="dropdown-content" style="min-width:250px"> |
| <a href="https://www.apache.org/foundation/">Apache Software Foundation</a> |
| <a href="https://incubator.apache.org/">Apache Incubator</a> |
| <a href="https://www.apache.org/licenses/">License</a> |
| <a href="/versions/1.9.1/api/faq/security.html">Security</a> |
| <a href="https://privacy.apache.org/policies/privacy-policy-public.html">Privacy</a> |
| <a href="https://www.apache.org/events/current-event">Events</a> |
| <a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a> |
| <a href="https://www.apache.org/foundation/thanks.html">Thanks</a> |
| </div> |
| </div> |
| <div class="dropdown"> |
| <span class="dropdown-header">1.9.1 |
| <svg class="dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg> |
| </span> |
| <div class="dropdown-content"> |
| <a href="/">master</a> |
| <a class="dropdown-option-active" href="/versions/1.9.1/">1.9.1</a> |
| <a href="/versions/1.8.0/">1.8.0</a> |
| <a href="/versions/1.7.0/">1.7.0</a> |
| <a href="/versions/1.6.0/">1.6.0</a> |
| <a href="/versions/1.5.0/">1.5.0</a> |
| <a href="/versions/1.4.1/">1.4.1</a> |
| <a href="/versions/1.3.1/">1.3.1</a> |
| <a href="/versions/1.2.1/">1.2.1</a> |
| <a href="/versions/1.1.0/">1.1.0</a> |
| <a href="/versions/1.0.0/">1.0.0</a> |
| <a href="/versions/0.12.1/">0.12.1</a> |
| <a href="/versions/0.11.0/">0.11.0</a> |
| </div> |
| </div> |
| </div> |
| </nav> |
| </div> |
| </header> |
| <main class="page-content" aria-label="Content"> |
| <script> |
| |
| </script> |
| <article class="post"> |
| |
| <header class="post-header wrapper"> |
| <h1 class="post-title">MNIST Competition</h1> |
| <h3></h3></header> |
| |
| <div class="post-content"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-3 docs-side-bar"> |
| |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| <div class="docs-card docs-side"> |
| <ul> |
| <div class="docs-action-btn"> |
| <a href="/versions/1.9.1/api/r"> <img src="/versions/1.9.1/assets/img/compass.svg" |
| class="docs-logo-docs">R Guide <span |
| class="span-accented">›</span></a> |
| </div> |
| <div class="docs-action-btn"> |
| <a href="/versions/1.9.1/api/r/docs/tutorials"> <img |
| src="/versions/1.9.1/assets/img/video-tutorial.svg" class="docs-logo-docs">R |
| Tutorials <span class="span-accented">›</span></a> |
| </div> |
| <div class="docs-action-btn"> |
| <a href="/versions/1.9.1/api/r/docs/api/R-package/build/mxnet-r-reference-manual.pdf"> <img src="/versions/1.9.1/assets/img/api.svg" |
| class="docs-logo-docs">R API Reference |
| <span class="span-accented">›</span></a> |
| </div> |
| |
| <!-- Let's show the list of tutorials --> |
| <br> |
| |
| <h3>Tutorials</h3> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| |
| |
| <li><a href="/versions/1.9.1/api/r/docs/tutorials/callback_function">Callback Function</a></li> |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| |
| <li><a href="/versions/1.9.1/api/r/docs/tutorials/char_rnn_model">Char RNN Model</a></li> |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| |
| <li><a href="/versions/1.9.1/api/r/docs/tutorials/classify_real_image_with_pretrained_model">Classify Images with a PreTrained Model</a></li> |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| |
| |
| <li><a href="/versions/1.9.1/api/r/docs/tutorials/custom_iterator">Custom Iterator Tutorial</a></li> |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| |
| <li><a href="/versions/1.9.1/api/r/docs/tutorials/custom_loss_function">Custom Loss Function</a></li> |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| |
| |
| <li><a href="/versions/1.9.1/api/r/docs/tutorials/five_minutes_neural_network">Five Minutes Neural Network</a></li> |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| |
| <li><a href="/versions/1.9.1/api/r/docs/tutorials/mnist_competition">MNIST Competition</a></li> |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| |
| |
| <li><a href="/versions/1.9.1/api/r/docs/tutorials/multi_dim_lstm">LSTM Time Series</a></li> |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| |
| <li><a href="/versions/1.9.1/api/r/docs/tutorials/ndarray">NDArray</a></li> |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| |
| <li><a href="/versions/1.9.1/api/r/docs/tutorials/symbol">Symbol</a></li> |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| |
| <!-- page-category --> |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| |
| <!-- resource-p --> |
| <!-- page --> |
| |
| </ul> |
| </div> |
| |
| |
| |
| |
| <!-- page --> |
| </ul> |
| </div> |
| <div class="col-9"> |
| <h1 id="handwritten-digits-classification-competition">Handwritten Digits Classification Competition</h1> |
| |
| <p><a href="http://yann.lecun.com/exdb/mnist/">MNIST</a> is a data set of handwritten digit images created by Yann LeCun and colleagues. Each digit is a 28 x 28 pixel greyscale image. MNIST has become a standard benchmark for testing classifiers on simple image input, and neural networks are strong models for image classification tasks like this one. Kaggle hosts a <a href="https://www.kaggle.com/c/digit-recognizer">long-running competition</a> based on this data set. |
| This tutorial shows how to use the <a href="https://github.com/apache/mxnet/tree/v1.x/R-package">MXNet R package</a> to compete in this challenge.</p> |
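<p>Each row of the Kaggle CSV files stores one 28 x 28 image as 784 consecutive pixel values. The base-R sketch below shows one way to turn such a row back into a matrix and plot it. It assumes the pixels are stored row by row (the first 28 values form the top row of the image), which is our reading of how the Kaggle data description orders the <code>pixel0</code> to <code>pixel783</code> columns. <code>row_to_image</code> and <code>show_digit</code> are hypothetical helpers, not part of MXNet.</p>

```r
# Reshape one 784-value row into a 28 x 28 matrix.
# Assumption: row-major pixel order, so value k (0-based) sits at
# image row k %/% 28 and image column k %% 28.
row_to_image <- function(pixels) {
  stopifnot(length(pixels) == 28 * 28)
  matrix(as.numeric(pixels), nrow = 28, ncol = 28, byrow = TRUE)
}

# Plot a digit with base graphics. image() puts z[1, 1] at the bottom-left and
# maps matrix rows to the x axis, so we flip the rows and transpose to get an
# upright digit. Background (0) is drawn white and ink (255) black.
show_digit <- function(pixels) {
  img <- row_to_image(pixels)
  image(t(img[28:1, ]), col = grey.colors(256, start = 1, end = 0), axes = FALSE)
}

# Usage once the data is loaded as in the next section
# (column 1 of train is the label, so drop it):
# show_digit(train[1, -1])
```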
| |
| <h2 id="loading-the-data">Loading the Data</h2> |
| |
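| <p>First, download <code>train.csv</code> and <code>test.csv</code> from <a href="https://www.kaggle.com/c/digit-recognizer/data">Kaggle</a> and put them in a <code>data/</code> folder in your working directory.</p> |
| |
| <p>Now read both files into R and convert them to matrices. In <code>train.csv</code>, the first column is the label and the remaining columns are pixel values:</p> |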
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">require</span><span class="p">(</span><span class="n">mxnet</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## Loading required package: mxnet |
| ## Loading required package: methods |
| </code></pre></div><div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">train</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">read.csv</span><span class="p">(</span><span class="s1">'data/train.csv'</span><span class="p">,</span><span class="w"> </span><span class="n">header</span><span class="o">=</span><span class="kc">TRUE</span><span class="p">)</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">read.csv</span><span class="p">(</span><span class="s1">'data/test.csv'</span><span class="p">,</span><span class="w"> </span><span class="n">header</span><span class="o">=</span><span class="kc">TRUE</span><span class="p">)</span><span class="w"> |
| </span><span class="n">train</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">data.matrix</span><span class="p">(</span><span class="n">train</span><span class="p">)</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">data.matrix</span><span class="p">(</span><span class="n">test</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">train.x</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">train</span><span class="p">[,</span><span class="m">-1</span><span class="p">]</span><span class="w"> |
| </span><span class="n">train.y</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">train</span><span class="p">[,</span><span class="m">1</span><span class="p">]</span><span class="w"> |
| </span></code></pre></div> |
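<p>Before preprocessing, it can help to confirm that the matrices have the layout the code above assumes: <code>train</code> has a label column followed by 784 pixel columns, and <code>test</code> has only the 784 pixel columns. The function below is a hypothetical sanity check, not an MXNet function; it is demonstrated on small synthetic matrices so that it runs without the Kaggle files.</p>

```r
# Hypothetical sanity check for the Kaggle MNIST layout (not part of MXNet).
# Call it after data.matrix() and before rescaling or transposing.
check_mnist_layout <- function(train, test, n_pixels = 28 * 28) {
  stopifnot(ncol(train) == n_pixels + 1)                 # label column + pixel columns
  stopifnot(ncol(test) == n_pixels)                      # pixel columns only
  stopifnot(all(train[, 1] %in% 0:9))                    # labels are the digits 0 to 9
  stopifnot(all(train[, -1] >= 0 & train[, -1] <= 255))  # raw greyscale range
  stopifnot(all(test >= 0 & test <= 255))
  invisible(TRUE)
}

# Synthetic matrices with the same layout as the real files
fake_train <- cbind(label = c(3, 7),
                    matrix(sample(0:255, 2 * 784, replace = TRUE), nrow = 2))
fake_test <- matrix(sample(0:255, 5 * 784, replace = TRUE), nrow = 5)
check_mnist_layout(fake_train, fake_test)

# With the real data: check_mnist_layout(train, test)
```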
| <p>Each image is stored as a single row of 784 (28 x 28) pixel values in <code>train</code> and <code>test</code>. Greyscale values range from 0 to 255; rescale them linearly to [0, 1] with the following commands:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">train.x</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">t</span><span class="p">(</span><span class="n">train.x</span><span class="o">/</span><span class="m">255</span><span class="p">)</span><span class="w"> |
| </span><span class="n">test</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">t</span><span class="p">(</span><span class="n">test</span><span class="o">/</span><span class="m">255</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div> |
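| <p>The <code>t()</code> calls also transpose the input matrices to npixel x nexamples, so that each column holds one example. This column-major layout is the format MXNet accepts, and it matches R's convention of storing matrices by column.</p> |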
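<p>To see what rescaling and transposing do to the shape of the data, here is a small base-R illustration on a made-up matrix of 3 examples with 4 pixels each (the sizes are arbitrary; the real data has 784 pixels per example). After the two steps, each column holds one example with values in [0, 1].</p>

```r
# Toy stand-in for train.x: 3 examples (rows) x 4 pixels (columns), values in [0, 255]
toy <- matrix(c(  0, 255, 128,  64,
                 10,  20,  30,  40,
                255, 255,   0,   0),
              nrow = 3, byrow = TRUE)

scaled <- t(toy / 255)  # rescale to [0, 1], then transpose

dim(toy)       # 3 4: examples x pixels
dim(scaled)    # 4 3: pixels x examples, one column per example
range(scaled)  # 0 1
scaled[, 2]    # the second example, rescaled
```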
| |
| <p>The ten digits are fairly evenly represented among the labels:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">table</span><span class="p">(</span><span class="n">train.y</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## train.y |
| ## 0 1 2 3 4 5 6 7 8 9 |
| ## 4132 4684 4177 4351 4072 3795 4137 4401 4063 4188 |
| </code></pre></div> |
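<p>"Fairly evenly" can be made concrete by turning the counts above into proportions. The snippet below simply re-enters the counts printed by <code>table(train.y)</code>; with the real data you would call <code>prop.table(table(train.y))</code> directly.</p>

```r
# Counts copied from the table(train.y) output above
counts <- c(`0` = 4132, `1` = 4684, `2` = 4177, `3` = 4351, `4` = 4072,
            `5` = 3795, `6` = 4137, `7` = 4401, `8` = 4063, `9` = 4188)

sum(counts)                   # 42000 training images
round(prop.table(counts), 3)  # each digit's share, from about 0.090 to 0.112
max(counts) / min(counts)     # the most common digit (1) is about 1.23x as frequent as the rarest (5)
```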
| <h2 id="configuring-the-network">Configuring the Network</h2> |
| |
| <p>Now that we have the data, let's configure the structure of our network:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="s2">"data"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">fc1</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="p">,</span><span class="w"> </span><span class="n">name</span><span class="o">=</span><span class="s2">"fc1"</span><span class="p">,</span><span class="w"> </span><span class="n">num_hidden</span><span class="o">=</span><span class="m">128</span><span class="p">)</span><span class="w"> |
| </span><span class="n">act1</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">fc1</span><span class="p">,</span><span class="w"> </span><span class="n">name</span><span class="o">=</span><span class="s2">"relu1"</span><span class="p">,</span><span class="w"> </span><span class="n">act_type</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">fc2</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">act1</span><span class="p">,</span><span class="w"> </span><span class="n">name</span><span class="o">=</span><span class="s2">"fc2"</span><span class="p">,</span><span class="w"> </span><span class="n">num_hidden</span><span class="o">=</span><span class="m">64</span><span class="p">)</span><span class="w"> |
| </span><span class="n">act2</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">fc2</span><span class="p">,</span><span class="w"> </span><span class="n">name</span><span class="o">=</span><span class="s2">"relu2"</span><span class="p">,</span><span class="w"> </span><span class="n">act_type</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">fc3</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">act2</span><span class="p">,</span><span class="w"> </span><span class="n">name</span><span class="o">=</span><span class="s2">"fc3"</span><span class="p">,</span><span class="w"> </span><span class="n">num_hidden</span><span class="o">=</span><span class="m">10</span><span class="p">)</span><span class="w"> |
| </span><span class="n">softmax</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.SoftmaxOutput</span><span class="p">(</span><span class="n">fc3</span><span class="p">,</span><span class="w"> </span><span class="n">name</span><span class="o">=</span><span class="s2">"sm"</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div> |
| <ol> |
| <li>In <code>mxnet</code>, networks are configured with the <code>symbol</code> data type. <code>data <- mx.symbol.Variable("data")</code> creates a symbol named <code>data</code> that represents the input data, i.e., the input layer.</li> |
| <li>The first hidden layer is <code>fc1 <- mx.symbol.FullyConnected(data, name="fc1", num_hidden=128)</code>. Its arguments are the input (<code>data</code>), the layer name, and the number of hidden neurons (128).</li> |
| <li>The activation is <code>act1 <- mx.symbol.Activation(fc1, name="relu1", act_type="relu")</code>, which applies a ReLU to the output of the first hidden layer, <code>fc1</code>.</li> |
| <li>The second hidden layer, <code>fc2</code>, takes <code>act1</code> as input and has 64 hidden neurons.</li> |
| <li>The second activation, <code>act2</code>, is the same as <code>act1</code> except for its input (<code>fc2</code>) and its name.</li> |
| <li><code>fc3</code> is the output layer. Because there are 10 digits, it has 10 neurons.</li> |
| <li>Finally, <code>mx.symbol.SoftmaxOutput</code> applies the softmax function to <code>fc3</code>, turning its outputs into a probability for each digit.</li> |
| </ol> |
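<p>The symbols above only describe the computation; MXNet carries it out during training. As a rough picture of what the network computes for a batch of images, here is a plain base-R forward pass with the same layer sizes (784, 128, 64, 10). It is an illustrative sketch, not MXNet's implementation: the weights are random draws from [-0.07, 0.07], chosen to mirror the <code>mx.init.uniform(0.07)</code> initializer used in the training step, the biases are simply set to zero, and all helper names (<code>forward_r</code>, <code>softmax_cols</code>, and so on) are hypothetical and picked so they do not overwrite the <code>softmax</code> symbol defined above.</p>

```r
set.seed(0)  # base R's RNG, independent of mx.set.seed

# Input size followed by the sizes of fc1, fc2 and fc3
layer_sizes <- c(784, 128, 64, 10)

# Weights drawn from U(-0.07, 0.07); biases start at zero.
# sketch_W[[i]] maps layer_sizes[i] units to layer_sizes[i + 1] units.
sketch_W <- lapply(1:3, function(i)
  matrix(runif(layer_sizes[i + 1] * layer_sizes[i], -0.07, 0.07),
         nrow = layer_sizes[i + 1]))
sketch_b <- lapply(1:3, function(i) rep(0, layer_sizes[i + 1]))

relu_r <- function(z) pmax(z, 0)  # element-wise ReLU, keeps the matrix shape

# Column-wise softmax: each column (one example) becomes a probability vector
softmax_cols <- function(z) {
  z <- sweep(z, 2, apply(z, 2, max))  # subtract each column's max for numerical stability
  e <- exp(z)
  sweep(e, 2, colSums(e), "/")
}

# x: 784 x n matrix with one example per column (the same layout as train.x)
forward_r <- function(x) {
  h1 <- relu_r(sketch_W[[1]] %*% x + sketch_b[[1]])   # fc1 + relu1: 128 x n
  h2 <- relu_r(sketch_W[[2]] %*% h1 + sketch_b[[2]])  # fc2 + relu2: 64 x n
  softmax_cols(sketch_W[[3]] %*% h2 + sketch_b[[3]])  # fc3 + softmax: 10 x n
}

x_fake <- matrix(runif(784 * 5), nrow = 784)  # 5 fake images with pixels in [0, 1]
probs <- forward_r(x_fake)
dim(probs)      # 10 rows (digits) by 5 columns (images)
colSums(probs)  # every column sums to 1, up to rounding

# Trainable parameters: weights plus biases of fc1, fc2 and fc3
n_params <- sum(lengths(sketch_W)) + sum(lengths(sketch_b))
n_params        # 109386
```

<p>Adding the bias vector to the matrix product relies on R recycling it down each column, which works because the bias length equals the number of rows.</p>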
| |
| <h2 id="training">Training</h2> |
| |
| <p>We are almost ready for the training process. Before we start the computation, let's decide which device to use:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">devices</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.cpu</span><span class="p">()</span><span class="w"> |
| </span></code></pre></div> |
| <p>Here we use the CPU. Now run the following command to train the neural network. <code>mx.set.seed</code> sets the seed of MXNet's random number generator, so that random steps such as weight initialization are reproducible:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">mx.set.seed</span><span class="p">(</span><span class="m">0</span><span class="p">)</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.model.FeedForward.create</span><span class="p">(</span><span class="n">softmax</span><span class="p">,</span><span class="w"> </span><span class="n">X</span><span class="o">=</span><span class="n">train.x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="o">=</span><span class="n">train.y</span><span class="p">,</span><span class="w"> |
| </span><span class="n">ctx</span><span class="o">=</span><span class="n">devices</span><span class="p">,</span><span class="w"> </span><span class="n">num.round</span><span class="o">=</span><span class="m">10</span><span class="p">,</span><span class="w"> </span><span class="n">array.batch.size</span><span class="o">=</span><span class="m">100</span><span class="p">,</span><span class="w"> |
| </span><span class="n">learning.rate</span><span class="o">=</span><span class="m">0.07</span><span class="p">,</span><span class="w"> </span><span class="n">momentum</span><span class="o">=</span><span class="m">0.9</span><span class="p">,</span><span class="w"> </span><span class="n">eval.metric</span><span class="o">=</span><span class="n">mx.metric.accuracy</span><span class="p">,</span><span class="w"> |
| </span><span class="n">initializer</span><span class="o">=</span><span class="n">mx.init.uniform</span><span class="p">(</span><span class="m">0.07</span><span class="p">),</span><span class="w"> |
| </span><span class="n">epoch.end.callback</span><span class="o">=</span><span class="n">mx.callback.log.train.metric</span><span class="p">(</span><span class="m">100</span><span class="p">))</span><span class="w"> |
| </span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## Start training with 1 devices |
| ## Batch [100] Train-accuracy=0.6563 |
| ## Batch [200] Train-accuracy=0.777999999999999 |
| ## Batch [300] Train-accuracy=0.827466666666665 |
| ## Batch [400] Train-accuracy=0.855499999999999 |
| ## [1] Train-accuracy=0.859832935560859 |
| ## Batch [100] Train-accuracy=0.9529 |
| ## Batch [200] Train-accuracy=0.953049999999999 |
| ## Batch [300] Train-accuracy=0.955866666666666 |
| ## Batch [400] Train-accuracy=0.957525000000001 |
| ## [2] Train-accuracy=0.958309523809525 |
| ## Batch [100] Train-accuracy=0.968 |
| ## Batch [200] Train-accuracy=0.9677 |
| ## Batch [300] Train-accuracy=0.9696 |
| ## Batch [400] Train-accuracy=0.970650000000002 |
| ## [3] Train-accuracy=0.970809523809526 |
| ## Batch [100] Train-accuracy=0.973 |
| ## Batch [200] Train-accuracy=0.974249999999999 |
| ## Batch [300] Train-accuracy=0.976 |
| ## Batch [400] Train-accuracy=0.977100000000003 |
| ## [4] Train-accuracy=0.977452380952384 |
| ## Batch [100] Train-accuracy=0.9834 |
| ## Batch [200] Train-accuracy=0.981949999999999 |
| ## Batch [300] Train-accuracy=0.981900000000001 |
| ## Batch [400] Train-accuracy=0.982600000000003 |
| ## [5] Train-accuracy=0.983000000000003 |
| ## Batch [100] Train-accuracy=0.983399999999999 |
| ## Batch [200] Train-accuracy=0.98405 |
| ## Batch [300] Train-accuracy=0.985000000000001 |
| ## Batch [400] Train-accuracy=0.985725000000003 |
| ## [6] Train-accuracy=0.985952380952384 |
| ## Batch [100] Train-accuracy=0.988999999999999 |
| ## Batch [200] Train-accuracy=0.9876 |
| ## Batch [300] Train-accuracy=0.988100000000001 |
| ## Batch [400] Train-accuracy=0.988750000000003 |
| ## [7] Train-accuracy=0.988880952380955 |
| ## Batch [100] Train-accuracy=0.991999999999999 |
| ## Batch [200] Train-accuracy=0.9912 |
| ## Batch [300] Train-accuracy=0.990066666666668 |
| ## Batch [400] Train-accuracy=0.990275000000003 |
| ## [8] Train-accuracy=0.990452380952384 |
| ## Batch [100] Train-accuracy=0.9937 |
| ## Batch [200] Train-accuracy=0.99235 |
| ## Batch [300] Train-accuracy=0.991966666666668 |
| ## Batch [400] Train-accuracy=0.991425000000003 |
| ## [9] Train-accuracy=0.991500000000003 |
| ## Batch [100] Train-accuracy=0.9942 |
| ## Batch [200] Train-accuracy=0.99245 |
| ## Batch [300] Train-accuracy=0.992433333333334 |
| ## Batch [400] Train-accuracy=0.992275000000002 |
| ## [10] Train-accuracy=0.992380952380955 |
| </code></pre></div> |
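| <p>The logged <code>Train-accuracy</code> appears to be a running average over all batches processed so far in the current round: it climbs within a round and starts over at each new round. The base-R sketch below reproduces that bookkeeping on made-up per-batch counts (the variable names are illustrative, not part of <code>mxnet</code>), reporting every 100 batches in the same format as <code>mx.callback.log.train.metric(100)</code> above.</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"># Made-up per-batch counts: 400 batches of 100 examples, with the number |
| # of correct predictions rising over the round. |
| batch.size <- 100 |
| n.batches <- 400 |
| correct <- round(seq(60, 99, length.out = n.batches)) |
| |
| # Running accuracy: correct predictions so far / examples seen so far. |
| running.acc <- cumsum(correct) / (batch.size * seq_len(n.batches)) |
| |
| # Report every `period` batches, mimicking the log format above. |
| period <- 100 |
| logged <- running.acc[seq(period, n.batches, by = period)] |
| for (k in seq_along(logged)) { |
|   cat(sprintf("Batch [%d] Train-accuracy=%.4f\n", as.integer(k * period), logged[k])) |
| } |
| </code></pre></div> |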
| <h2 id="making-a-prediction-and-submitting-to-the-competition">Making a Prediction and Submitting to the Competition</h2> |
| |
| <p>To make predictions on the test set, run:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">preds</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test</span><span class="p">)</span><span class="w"> |
| </span><span class="nf">dim</span><span class="p">(</span><span class="n">preds</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## [1] 10 28000 |
| </code></pre></div> |
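| <p>The result is a matrix with 10 rows and 28000 columns: each column holds the 10 class probabilities that the output layer assigns to one test image. To get the most probable digit for each image, transpose the matrix, apply <code>max.col</code>, and subtract 1, because the column index runs from 1 to 10 while the digits run from 0 to 9:</p> |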
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">pred.label</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">max.col</span><span class="p">(</span><span class="n">t</span><span class="p">(</span><span class="n">preds</span><span class="p">))</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="m">1</span><span class="w"> |
| </span><span class="n">table</span><span class="p">(</span><span class="n">pred.label</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## pred.label |
| ## 0 1 2 3 4 5 6 7 8 9 |
| ## 2818 3195 2744 2767 2683 2596 2798 2790 2784 2825 |
| </code></pre></div> |
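| <p>To see why the transpose and the <code>- 1</code> are needed, here is the same step on a tiny made-up probability matrix with the layout of <code>preds</code> (one column per image, rows for digits 0 to 9). Passing <code>ties.method = "first"</code> is an optional addition here: by default <code>max.col</code> breaks ties at random.</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"># Made-up stand-in for preds: 10 rows (digits 0-9) by 3 columns (images). |
| toy.preds <- matrix(0.01, nrow = 10, ncol = 3) |
| toy.preds[3, 1] <- 0.91  # image 1 peaks in row 3 |
| toy.preds[8, 2] <- 0.91  # image 2 peaks in row 8 |
| toy.preds[1, 3] <- 0.91  # image 3 peaks in row 1 |
| |
| # t() gives one row per image; max.col() returns the 1-based column of each |
| # row's maximum, and subtracting 1 turns positions 1-10 into digits 0-9. |
| toy.label <- max.col(t(toy.preds), ties.method = "first") - 1 |
| print(toy.label)  # 2 7 0 |
| </code></pre></div> |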
| <p>Finally, write the labels to a .csv file with an <code>ImageId</code> column and a <code>Label</code> column, and the submission is ready for the competition:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">submission</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">data.frame</span><span class="p">(</span><span class="n">ImageId</span><span class="o">=</span><span class="m">1</span><span class="o">:</span><span class="n">ncol</span><span class="p">(</span><span class="n">test</span><span class="p">),</span><span class="w"> </span><span class="n">Label</span><span class="o">=</span><span class="n">pred.label</span><span class="p">)</span><span class="w"> |
| </span><span class="n">write.csv</span><span class="p">(</span><span class="n">submission</span><span class="p">,</span><span class="w"> </span><span class="n">file</span><span class="o">=</span><span class="s1">'submission.csv'</span><span class="p">,</span><span class="w"> </span><span class="n">row.names</span><span class="o">=</span><span class="kc">FALSE</span><span class="p">,</span><span class="w"> </span><span class="n">quote</span><span class="o">=</span><span class="kc">FALSE</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div> |
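| <p>In the <code>write.csv</code> call, <code>row.names = FALSE</code> drops R's row-name column and <code>quote = FALSE</code> leaves the header unquoted, so each line is a bare <code>ImageId,Label</code> pair. A quick round trip through a temporary file with made-up labels shows the resulting text:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"># Made-up labels for five test images. |
| toy.label <- c(2, 0, 9, 9, 3) |
| toy.sub <- data.frame(ImageId = seq_along(toy.label), Label = toy.label) |
| |
| # Write with the same options as the submission above, then read the raw text. |
| path <- tempfile(fileext = ".csv") |
| write.csv(toy.sub, file = path, row.names = FALSE, quote = FALSE) |
| csv.lines <- readLines(path) |
| print(csv.lines)  # "ImageId,Label" "1,2" "2,0" "3,9" "4,9" "5,3" |
| </code></pre></div> |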
| <h2 id="lenet">LeNet</h2> |
| |
| <p>Now let's try a different network structure: <a href="http://yann.lecun.com/exdb/lenet/">LeNet</a>, a convolutional network proposed by Yann LeCun for recognizing handwritten digits. We'll demonstrate how to construct and train a LeNet in <code>mxnet</code>.</p> |
| |
| <p>First, we construct the network:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="c1"># input</span><span class="w"> |
| </span><span class="n">data</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span><span class="w"> |
| </span><span class="c1"># first conv</span><span class="w"> |
| </span><span class="n">conv1</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Convolution</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">,</span><span class="w"> </span><span class="n">kernel</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="m">5</span><span class="p">,</span><span class="m">5</span><span class="p">),</span><span class="w"> </span><span class="n">num_filter</span><span class="o">=</span><span class="m">20</span><span class="p">)</span><span class="w"> |
| </span><span class="n">tanh1</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">conv1</span><span class="p">,</span><span class="w"> </span><span class="n">act_type</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">pool1</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Pooling</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">tanh1</span><span class="p">,</span><span class="w"> </span><span class="n">pool_type</span><span class="o">=</span><span class="s2">"max"</span><span class="p">,</span><span class="w"> |
| </span><span class="n">kernel</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="m">2</span><span class="p">,</span><span class="m">2</span><span class="p">),</span><span class="w"> </span><span class="n">stride</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="m">2</span><span class="p">,</span><span class="m">2</span><span class="p">))</span><span class="w"> |
| </span><span class="c1"># second conv</span><span class="w"> |
| </span><span class="n">conv2</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Convolution</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">pool1</span><span class="p">,</span><span class="w"> </span><span class="n">kernel</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="m">5</span><span class="p">,</span><span class="m">5</span><span class="p">),</span><span class="w"> </span><span class="n">num_filter</span><span class="o">=</span><span class="m">50</span><span class="p">)</span><span class="w"> |
| </span><span class="n">tanh2</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">conv2</span><span class="p">,</span><span class="w"> </span><span class="n">act_type</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">pool2</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Pooling</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">tanh2</span><span class="p">,</span><span class="w"> </span><span class="n">pool_type</span><span class="o">=</span><span class="s2">"max"</span><span class="p">,</span><span class="w"> |
| </span><span class="n">kernel</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="m">2</span><span class="p">,</span><span class="m">2</span><span class="p">),</span><span class="w"> </span><span class="n">stride</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="m">2</span><span class="p">,</span><span class="m">2</span><span class="p">))</span><span class="w"> |
| </span><span class="c1"># first fullc</span><span class="w"> |
| </span><span class="n">flatten</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Flatten</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">pool2</span><span class="p">)</span><span class="w"> |
| </span><span class="n">fc1</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">flatten</span><span class="p">,</span><span class="w"> </span><span class="n">num_hidden</span><span class="o">=</span><span class="m">500</span><span class="p">)</span><span class="w"> |
| </span><span class="n">tanh3</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Activation</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">fc1</span><span class="p">,</span><span class="w"> </span><span class="n">act_type</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span><span class="w"> |
| </span><span class="c1"># second fullc</span><span class="w"> |
| </span><span class="n">fc2</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">tanh3</span><span class="p">,</span><span class="w"> </span><span class="n">num_hidden</span><span class="o">=</span><span class="m">10</span><span class="p">)</span><span class="w"> |
| </span><span class="c1"># loss</span><span class="w"> |
| </span><span class="n">lenet</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.SoftmaxOutput</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">fc2</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div> |
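| <p>It can help to trace the tensor sizes through this network. The base-R sketch below assumes MXNet's defaults for the layers above (stride 1 and no padding for convolutions, "valid" pooling, and a bias term in every convolution and fully connected layer). Under those assumptions a 28 x 28 input shrinks to 50 feature maps of 4 x 4 before <code>mx.symbol.Flatten</code>, and the network has 431,080 trainable parameters.</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"># Output length along one axis for a window that slides without padding. |
| out.size <- function(n, kernel, stride = 1) (n - kernel) %/% stride + 1 |
| |
| side <- 28 |
| side <- out.size(side, 5)     # conv1 (5x5):           28 -> 24 |
| side <- out.size(side, 2, 2)  # pool1 (2x2, stride 2): 24 -> 12 |
| side <- out.size(side, 5)     # conv2 (5x5):           12 -> 8 |
| side <- out.size(side, 2, 2)  # pool2 (2x2, stride 2):  8 -> 4 |
| flat <- 50 * side * side      # flatten: 50 maps x 4 x 4 = 800 values |
| |
| # Weights plus one bias per filter or hidden unit (assumed defaults). |
| params <- c(conv1 = 20 * (1 * 5 * 5) + 20, |
|             conv2 = 50 * (20 * 5 * 5) + 50, |
|             fc1   = flat * 500 + 500, |
|             fc2   = 500 * 10 + 10) |
| print(params) |
| print(sum(params))  # 431080 |
| </code></pre></div> |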
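| <p>The convolution layers expect image-shaped input rather than flat pixel vectors, so reshape the matrices into 4-D arrays: 28 x 28 pixels, 1 channel, and one slice per image:</p> |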
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="n">train.array</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">train.x</span><span class="w"> |
| </span><span class="nf">dim</span><span class="p">(</span><span class="n">train.array</span><span class="p">)</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="m">28</span><span class="p">,</span><span class="w"> </span><span class="m">28</span><span class="p">,</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">ncol</span><span class="p">(</span><span class="n">train.x</span><span class="p">))</span><span class="w"> |
| </span><span class="n">test.array</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">test</span><span class="w"> |
| </span><span class="nf">dim</span><span class="p">(</span><span class="n">test.array</span><span class="p">)</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="m">28</span><span class="p">,</span><span class="w"> </span><span class="m">28</span><span class="p">,</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">ncol</span><span class="p">(</span><span class="n">test</span><span class="p">))</span><span class="w"> |
| </span></code></pre></div> |
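| <p>Setting <code>dim()</code> does not move any data: R fills the new dimensions in column-major order, so the 784 values in column <code>i</code> of the matrix become the 28 x 28 slice <code>[, , 1, i]</code> of the array. This assumes, as the reshape above does, that each column holds one image's pixels. A quick check on a small made-up matrix:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"># Made-up stand-in for train.x: 784 pixel rows by 3 image columns. |
| toy.x <- matrix(seq_len(784 * 3), nrow = 784, ncol = 3) |
| |
| toy.array <- toy.x |
| dim(toy.array) <- c(28, 28, 1, ncol(toy.x)) |
| |
| # Slice i of the array equals column i of the matrix laid out as 28 x 28. |
| i <- 2 |
| same <- identical(toy.array[, , 1, i], matrix(toy.x[, i], nrow = 28, ncol = 28)) |
| print(same)            # TRUE |
| print(dim(toy.array))  # 28 28 1 3 |
| </code></pre></div> |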
| <p>To compare training speed on different devices, define a CPU device and a list of GPU devices (set <code>n.gpu</code> to the number of GPUs on your machine):</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="n">n.gpu</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="m">1</span><span class="w"> |
| </span><span class="n">device.cpu</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.cpu</span><span class="p">()</span><span class="w"> |
| </span><span class="n">device.gpu</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">lapply</span><span class="p">(</span><span class="m">0</span><span class="o">:</span><span class="p">(</span><span class="n">n.gpu</span><span class="m">-1</span><span class="p">),</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">i</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"> |
| </span><span class="n">mx.gpu</span><span class="p">(</span><span class="n">i</span><span class="p">)</span><span class="w"> |
| </span><span class="p">})</span><span class="w"> |
| </span></code></pre></div> |
| <p>Passing a list of devices tells MXNet to train on multiple GPUs. You can do the same with CPUs, but because computation on a CPU is already multi-threaded, the gain is smaller than with GPUs.</p> |
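| <p>As a rough illustration of the data-parallel idea (a hypothetical helper, not MXNet's internal code), each batch can be divided as evenly as possible among the devices. In data-parallel training each device then computes gradients on its slice, and the gradients are combined for the weight update. With <code>array.batch.size=100</code> and three devices, the slices would hold 34, 33, and 33 examples:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"># Hypothetical helper: split batch.size examples across n.devices as |
| # evenly as possible and report the 1-based index range of each slice. |
| batch.slices <- function(batch.size, n.devices) { |
|   sizes <- rep(batch.size %/% n.devices, n.devices) |
|   extra <- batch.size %% n.devices |
|   sizes[seq_len(extra)] <- sizes[seq_len(extra)] + 1  # spread the remainder |
|   ends <- cumsum(sizes) |
|   data.frame(device = seq_len(n.devices) - 1, begin = ends - sizes + 1, end = ends) |
| } |
| |
| print(batch.slices(100, 3)) |
| </code></pre></div> |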
| |
| <p>Start by training on the CPU. Because this is slow, we run it for just one round (<code>num.round=1</code>):</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">mx.set.seed</span><span class="p">(</span><span class="m">0</span><span class="p">)</span><span class="w"> |
| </span><span class="n">tic</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">proc.time</span><span class="p">()</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.model.FeedForward.create</span><span class="p">(</span><span class="n">lenet</span><span class="p">,</span><span class="w"> </span><span class="n">X</span><span class="o">=</span><span class="n">train.array</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="o">=</span><span class="n">train.y</span><span class="p">,</span><span class="w"> |
| </span><span class="n">ctx</span><span class="o">=</span><span class="n">device.cpu</span><span class="p">,</span><span class="w"> </span><span class="n">num.round</span><span class="o">=</span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">array.batch.size</span><span class="o">=</span><span class="m">100</span><span class="p">,</span><span class="w"> |
| </span><span class="n">learning.rate</span><span class="o">=</span><span class="m">0.05</span><span class="p">,</span><span class="w"> </span><span class="n">momentum</span><span class="o">=</span><span class="m">0.9</span><span class="p">,</span><span class="w"> </span><span class="n">wd</span><span class="o">=</span><span class="m">0.00001</span><span class="p">,</span><span class="w"> |
| </span><span class="n">eval.metric</span><span class="o">=</span><span class="n">mx.metric.accuracy</span><span class="p">,</span><span class="w"> |
| </span><span class="n">epoch.end.callback</span><span class="o">=</span><span class="n">mx.callback.log.train.metric</span><span class="p">(</span><span class="m">100</span><span class="p">))</span><span class="w"> |
| </span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## Start training with 1 devices |
| ## Batch [100] Train-accuracy=0.1066 |
| ## Batch [200] Train-accuracy=0.16495 |
| ## Batch [300] Train-accuracy=0.401766666666667 |
| ## Batch [400] Train-accuracy=0.537675 |
| ## [1] Train-accuracy=0.557136038186157 |
| </code></pre></div><div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">print</span><span class="p">(</span><span class="nf">proc.time</span><span class="p">()</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">tic</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## user system elapsed |
| ## 130.030 204.976 83.821 |
| </code></pre></div> |
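| <p>In the <code>proc.time()</code> difference, <code>user</code> and <code>system</code> count CPU seconds summed over all threads, while <code>elapsed</code> is wall-clock time. Their ratio is a rough estimate of how many cores were busy on average. For the CPU run above it comes to about 4, which fits the multi-threaded CPU computation mentioned earlier. A small base-R helper (a hypothetical name) for that ratio:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"># Hypothetical helper: average number of busy cores during a timed run, |
| # computed from a proc.time() difference (user + system CPU over elapsed). |
| busy.cores <- function(t) (t[["user.self"]] + t[["sys.self"]]) / t[["elapsed"]] |
| |
| # The CPU timing printed above, entered by hand with proc.time()'s names. |
| cpu.run <- c(user.self = 130.030, sys.self = 204.976, elapsed = 83.821) |
| print(round(busy.cores(cpu.run), 2))  # 4 |
| |
| # After a real run: busy.cores(proc.time() - tic) |
| </code></pre></div> |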
| <p>Now train on the GPU, this time for five rounds:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">mx.set.seed</span><span class="p">(</span><span class="m">0</span><span class="p">)</span><span class="w"> |
| </span><span class="n">tic</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">proc.time</span><span class="p">()</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.model.FeedForward.create</span><span class="p">(</span><span class="n">lenet</span><span class="p">,</span><span class="w"> </span><span class="n">X</span><span class="o">=</span><span class="n">train.array</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="o">=</span><span class="n">train.y</span><span class="p">,</span><span class="w"> |
| </span><span class="n">ctx</span><span class="o">=</span><span class="n">device.gpu</span><span class="p">,</span><span class="w"> </span><span class="n">num.round</span><span class="o">=</span><span class="m">5</span><span class="p">,</span><span class="w"> </span><span class="n">array.batch.size</span><span class="o">=</span><span class="m">100</span><span class="p">,</span><span class="w"> |
| </span><span class="n">learning.rate</span><span class="o">=</span><span class="m">0.05</span><span class="p">,</span><span class="w"> </span><span class="n">momentum</span><span class="o">=</span><span class="m">0.9</span><span class="p">,</span><span class="w"> </span><span class="n">wd</span><span class="o">=</span><span class="m">0.00001</span><span class="p">,</span><span class="w"> |
| </span><span class="n">eval.metric</span><span class="o">=</span><span class="n">mx.metric.accuracy</span><span class="p">,</span><span class="w"> |
| </span><span class="n">epoch.end.callback</span><span class="o">=</span><span class="n">mx.callback.log.train.metric</span><span class="p">(</span><span class="m">100</span><span class="p">))</span><span class="w"> |
| </span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## Start training with 1 devices |
| ## Batch [100] Train-accuracy=0.1066 |
| ## Batch [200] Train-accuracy=0.1596 |
| ## Batch [300] Train-accuracy=0.3983 |
| ## Batch [400] Train-accuracy=0.533975 |
| ## [1] Train-accuracy=0.553532219570405 |
| ## Batch [100] Train-accuracy=0.958 |
| ## Batch [200] Train-accuracy=0.96155 |
| ## Batch [300] Train-accuracy=0.966100000000001 |
| ## Batch [400] Train-accuracy=0.968550000000003 |
| ## [2] Train-accuracy=0.969071428571432 |
| ## Batch [100] Train-accuracy=0.977 |
| ## Batch [200] Train-accuracy=0.97715 |
| ## Batch [300] Train-accuracy=0.979566666666668 |
| ## Batch [400] Train-accuracy=0.980900000000003 |
| ## [3] Train-accuracy=0.981309523809527 |
| ## Batch [100] Train-accuracy=0.9853 |
| ## Batch [200] Train-accuracy=0.985899999999999 |
| ## Batch [300] Train-accuracy=0.986966666666668 |
| ## Batch [400] Train-accuracy=0.988150000000002 |
| ## [4] Train-accuracy=0.988452380952384 |
| ## Batch [100] Train-accuracy=0.990199999999999 |
| ## Batch [200] Train-accuracy=0.98995 |
| ## Batch [300] Train-accuracy=0.990600000000001 |
| ## Batch [400] Train-accuracy=0.991325000000002 |
| ## [5] Train-accuracy=0.991523809523812 |
| </code></pre></div><div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">print</span><span class="p">(</span><span class="nf">proc.time</span><span class="p">()</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">tic</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## user system elapsed |
| ## 9.288 1.680 6.889 |
| </code></pre></div> |
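| <p>Because the two runs trained for different numbers of rounds, a fairer comparison is elapsed time per round. Here is that arithmetic on the timings printed above, in base R. Wall-clock time also includes setup such as initialization, so treat the ratio as a rough figure for this particular machine:</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"># Elapsed seconds and rounds for the two runs above. |
| cpu.elapsed <- 83.821; cpu.rounds <- 1 |
| gpu.elapsed <- 6.889;  gpu.rounds <- 5 |
| |
| cpu.per.round <- cpu.elapsed / cpu.rounds |
| gpu.per.round <- gpu.elapsed / gpu.rounds |
| speedup <- cpu.per.round / gpu.per.round |
| |
| cat(sprintf("CPU %.1f s/round, GPU %.2f s/round, speedup about %.0fx\n", |
|             cpu.per.round, gpu.per.round, speedup)) |
| </code></pre></div> |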
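| <p>The GPU finished five rounds in about 6.9 seconds of elapsed time, while the CPU needed about 84 seconds for a single round, so on this machine training on the GPU is roughly 60 times faster per round. Now we can predict with the LeNet model and submit the result to Kaggle to see how our ranking improves:</p> |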
| <div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">preds</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test.array</span><span class="p">)</span><span class="w"> |
| </span><span class="n">pred.label</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">max.col</span><span class="p">(</span><span class="n">t</span><span class="p">(</span><span class="n">preds</span><span class="p">))</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="m">1</span><span class="w"> |
| </span><span class="n">submission</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">data.frame</span><span class="p">(</span><span class="n">ImageId</span><span class="o">=</span><span class="m">1</span><span class="o">:</span><span class="n">ncol</span><span class="p">(</span><span class="n">test</span><span class="p">),</span><span class="w"> </span><span class="n">Label</span><span class="o">=</span><span class="n">pred.label</span><span class="p">)</span><span class="w"> |
| </span><span class="n">write.csv</span><span class="p">(</span><span class="n">submission</span><span class="p">,</span><span class="w"> </span><span class="n">file</span><span class="o">=</span><span class="s1">'submission.csv'</span><span class="p">,</span><span class="w"> </span><span class="n">row.names</span><span class="o">=</span><span class="kc">FALSE</span><span class="p">,</span><span class="w"> </span><span class="n">quote</span><span class="o">=</span><span class="kc">FALSE</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div> |
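| <p>Before uploading, it can be worth checking the file's shape. The helper below is a hypothetical sketch that encodes the format used in this tutorial: the columns <code>ImageId</code> and <code>Label</code>, one row per test image numbered from 1, and digit labels 0 to 9. On the real data you would call <code>check.submission(submission, ncol(test))</code>.</p> |
| <div class="highlight"><pre><code class="language-r" data-lang="r"># Hypothetical sanity check: stops with an error if the format is off. |
| check.submission <- function(sub, n.images) { |
|   stopifnot( |
|     identical(names(sub), c("ImageId", "Label")), |
|     nrow(sub) == n.images, |
|     all(sub$ImageId == seq_len(n.images)), |
|     all(sub$Label %in% 0:9) |
|   ) |
|   invisible(TRUE) |
| } |
| |
| # Example with a small made-up data frame. |
| toy.sub <- data.frame(ImageId = 1:4, Label = c(7, 2, 1, 0)) |
| check.submission(toy.sub, 4) |
| </code></pre></div> |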
| <p><img src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/knitr/mnistCompetition-kaggle-submission.png" alt="Kaggle submission result"></p> |
| |
| <h2 id="next-steps">Next Steps</h2> |
| |
| <ul> |
| <li><a href="https://mxnet.io/tutorials/r/charRnnModel.html">Character Language Model using RNN</a></li> |
| </ul> |
| </div> |
| </div> |
| |
| </div> |
| </div> |
| |
| </article> |
| |
| </main><footer class="site-footer h-card"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-4"> |
| <h4 class="footer-category-title">Resources</h4> |
| <ul class="contact-list"> |
| <li><a href="/versions/1.9.1/community/contribute#mxnet-dev-communications">Mailing lists</a></li> |
| <li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li> |
| <li><a href="https://issues.apache.org/jira/projects/MXNET/issues">Jira Tracker</a></li> |
| <li><a href="https://github.com/apache/incubator-mxnet/labels/Roadmap">GitHub Roadmap</a></li> |
| <li><a href="https://medium.com/apache-mxnet">Blog</a></li> |
| <li><a href="https://discuss.mxnet.io">Forum</a></li> |
| <li><a href="/versions/1.9.1/community/contribute">Contribute</a></li> |
| </ul> |
| </div> |
| |
| <div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/incubator-mxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/incubator-mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul> |
| </div> |
| |
| <div class="col-4 footer-text"> |
| <p>A flexible and efficient library for deep learning.</p> |
| </div> |
| </div> |
| </div> |
| </footer> |
| <footer class="site-footer2"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-3"> |
| <img src="/versions/1.9.1/assets/img/apache_incubator_logo.png" class="footer-logo col-2"> |
| </div> |
| <div class="footer-bottom-warning col-9"> |
| <p>Apache MXNet is an effort undergoing incubation at <a href="http://www.apache.org/">The Apache Software Foundation</a> (ASF), <span |
| style="font-weight:bold">sponsored by the <i>Apache Incubator</i></span>. Incubation is required |
| of all newly accepted projects until a further review indicates that the infrastructure, |
| communications, and decision making process have stabilized in a manner consistent with other |
| successful ASF projects. While incubation status is not necessarily a reflection of the completeness |
| or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF. |
| </p><p>"Copyright © 2017-2022, The Apache Software Foundation. Apache MXNet, MXNet, Apache, the Apache |
| feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the |
| Apache Software Foundation."</p> |
| </div> |
| </div> |
| </div> |
| </footer> |
| |
| |
| |
| |
| </body> |
| |
| </html> |