blob: 1ca77ba3bb690d7438ce257f9e428233f084c98c [file] [log] [blame]
<!DOCTYPE html>
<html lang="en"><head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 -->
<title>Five Minutes Neural Network | Apache MXNet</title>
<meta name="generator" content="Jekyll v3.8.6" />
<meta property="og:title" content="Five Minutes Neural Network" />
<meta property="og:locale" content="en_US" />
<meta name="description" content="A flexible and efficient library for deep learning." />
<meta property="og:description" content="A flexible and efficient library for deep learning." />
<link rel="canonical" href="https://mxnet.apache.org/versions/1.8.0/api/r/docs/tutorials/five_minutes_neural_network" />
<meta property="og:url" content="https://mxnet.apache.org/versions/1.8.0/api/r/docs/tutorials/five_minutes_neural_network" />
<meta property="og:site_name" content="Apache MXNet" />
<script type="application/ld+json">
{"url":"https://mxnet.apache.org/versions/1.8.0/api/r/docs/tutorials/five_minutes_neural_network","@type":"WebPage","description":"A flexible and efficient library for deep learning.","headline":"Five Minutes Neural Network","@context":"https://schema.org"}</script>
<!-- End Jekyll SEO tag -->
<script src="https://medium-widget.pixelpoint.io/widget.js"></script>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.css" />
<link rel="stylesheet" href="/versions/1.8.0/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/1.8.0/feed.xml" title="Apache MXNet" /><script>
// Load Google Analytics only when none of the Do Not Track flags checked
// below (window.doNotTrack, navigator.doNotTrack, navigator.msDoNotTrack)
// is set to an opt-out value.
if(!(window.doNotTrack === "1" || navigator.doNotTrack === "1" || navigator.doNotTrack === "yes" || navigator.msDoNotTrack === "1")) {
// Loader: defines a stub window.ga that queues its arguments in ga.q until
// the real library arrives, stamps ga.l with the current time, then inserts
// an async <script> for analytics.js before the first <script> in the page.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
// Queue tracker creation for this property, then a page-view hit.
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
}
</script>
<script src="/versions/1.8.0/assets/js/jquery-3.3.1.min.js"></script><script src="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js" defer></script>
<script src="/versions/1.8.0/assets/js/globalSearch.js" defer></script>
<script src="/versions/1.8.0/assets/js/clipboard.js" defer></script>
<script src="/versions/1.8.0/assets/js/copycode.js" defer></script></head>
<body><header class="site-header" role="banner">
<script>
$(document).ready(function () {
  // Header tint: the background alpha starts at 0.4 and grows by 1/300 per
  // pixel scrolled, so the header becomes more opaque as the page moves down.
  var applyHeaderTint = function () {
    var alpha = $(window).scrollTop() / 300 + 0.4;
    var tint = "rgba(4,140,204," + alpha + ")";
    $('.site-header').css("background-color", tint);
  };

  // Paint once for the initial scroll position, then again on every scroll.
  applyHeaderTint();
  $(window).on("scroll", function () {
    applyHeaderTint();
  });

  // Menu highlight: mark every nav link whose resolved href is contained in
  // the current page URL.
  var currentUrl = window.location.href;
  $('.page-link')
    .filter(function () {
      return currentUrl.includes(this.href);
    })
    .addClass("page-current");
});
</script>
<div class="wrapper">
<a class="site-title" rel="author" href="/versions/1.8.0/"><img
src="/versions/1.8.0/assets/img/mxnet_logo.png" class="site-header-logo"></a>
<nav class="site-nav">
<input type="checkbox" id="nav-trigger" class="nav-trigger"/>
<label for="nav-trigger">
<span class="menu-icon">
<svg viewBox="0 0 18 15" width="18px" height="15px">
<path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
</svg>
</span>
</label>
<div class="gs-search-border">
<div id="gs-search-icon"></div>
<form id="global-search-form">
<input id="global-search" type="text" title="Search" placeholder="Search" />
<div id="global-search-dropdown-container">
<button class="gs-current-version btn" type="button" data-toggle="dropdown">
<span id="gs-current-version-label">1.8.0</span>
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown">
<li class="gs-opt gs-versions">master</li>
<li class="gs-opt gs-versions active">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
<span id="global-search-close">x</span>
</form>
</div>
<div class="trigger">
<div id="global-search-mobile-border">
<div id="gs-search-icon-mobile"></div>
<input id="global-search-mobile" placeholder="Search..." type="text"/>
<div id="global-search-dropdown-container-mobile">
<button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown">
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown-mobile">
<li class="gs-opt gs-versions">master</li>
<li class="gs-opt gs-versions active">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
</div>
<a class="page-link" href="/versions/1.8.0/get_started">Get Started</a>
<a class="page-link" href="/versions/1.8.0/blog">Blog</a>
<a class="page-link" href="/versions/1.8.0/features">Features</a>
<a class="page-link" href="/versions/1.8.0/ecosystem">Ecosystem</a>
<a class="page-link" href="/versions/1.8.0/api">Docs & Tutorials</a>
<a class="page-link" href="https://github.com/apache/incubator-mxnet">GitHub</a>
<div class="dropdown">
<span class="dropdown-header">1.8.0
<svg class="dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content">
<a href="/">master</a>
<a class="dropdown-option-active" href="/versions/1.8.0/">1.8.0</a>
<a href="/versions/1.7.0/">1.7.0</a>
<a href="/versions/1.6.0/">1.6.0</a>
<a href="/versions/1.5.0/">1.5.0</a>
<a href="/versions/1.4.1/">1.4.1</a>
<a href="/versions/1.3.1/">1.3.1</a>
<a href="/versions/1.2.1/">1.2.1</a>
<a href="/versions/1.1.0/">1.1.0</a>
<a href="/versions/1.0.0/">1.0.0</a>
<a href="/versions/0.12.1/">0.12.1</a>
<a href="/versions/0.11.0/">0.11.0</a>
</div>
</div>
</div>
</nav>
</div>
</header>
<main class="page-content" aria-label="Content">
<script>
</script>
<article class="post">
<header class="post-header wrapper">
<h1 class="post-title">Five Minutes Neural Network</h1>
<h3></h3></header>
<div class="post-content">
<div class="wrapper">
<div class="row">
<div class="col-3 docs-side-bar">
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<div class="docs-card docs-side">
<ul>
<div class="docs-action-btn">
<a href="/versions/1.8.0/api/r"> <img src="/versions/1.8.0/assets/img/compass.svg"
class="docs-logo-docs">R Guide <span
class="span-accented"></span></a>
</div>
<div class="docs-action-btn">
<a href="/versions/1.8.0/api/r/docs/tutorials"> <img
src="/versions/1.8.0/assets/img/video-tutorial.svg" class="docs-logo-docs">R
Tutorials <span class="span-accented"></span></a>
</div>
<div class="docs-action-btn">
<a href="/versions/1.8.0/api/r/docs/api/R-package/build/mxnet-r-reference-manual.pdf"> <img src="/versions/1.8.0/assets/img/api.svg"
class="docs-logo-docs">R API Reference
<span class="span-accented"></span></a>
</div>
<!-- Let's show the list of tutorials -->
<br>
<h3>Tutorials</h3>
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<li><a href="/versions/1.8.0/api/r/docs/tutorials/callback_function">Callback Function</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.8.0/api/r/docs/tutorials/char_rnn_model">Char RNN Model</a></li>
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.8.0/api/r/docs/tutorials/classify_real_image_with_pretrained_model">Classify Images with a PreTrained Model</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<li><a href="/versions/1.8.0/api/r/docs/tutorials/custom_iterator">Custom Iterator Tutorial</a></li>
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.8.0/api/r/docs/tutorials/custom_loss_function">Custom Loss Function</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<li><a href="/versions/1.8.0/api/r/docs/tutorials/five_minutes_neural_network">Five Minutes Neural Network</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.8.0/api/r/docs/tutorials/mnist_competition">MNIST Competition</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<li><a href="/versions/1.8.0/api/r/docs/tutorials/multi_dim_lstm">LSTM Time Series</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.8.0/api/r/docs/tutorials/ndarray">NDArray</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.8.0/api/r/docs/tutorials/symbol">NDArray</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page -->
</ul>
</div>
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page -->
</ul>
</div>
<div class="col-9">
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
<h1 id="develop-a-neural-network-with-mxnet-in-five-minutes">Develop a Neural Network with MXNet in Five Minutes</h1>
<p>This tutorial is designed for new users of the <code>mxnet</code> package for R. It shows how to construct a neural network in five minutes and how to use it first for a classification task and then for a regression task. The data we use is in the <code>mlbench</code> package. Instructions to install R and MXNet&#39;s R package in different environments can be found <a href="/get_started?version=master&amp;platform=linux&amp;language=r&amp;environ=pip&amp;processor=cpu">here</a>.</p>
<h2 id="classification">Classification</h2>
<div class="highlight"><pre><code class="language-" data-lang=""> ## Loading required package: mlbench
</code></pre></div><div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="o">!</span><span class="n">require</span><span class="p">(</span><span class="n">mlbench</span><span class="p">))</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="n">install.packages</span><span class="p">(</span><span class="s1">'mlbench'</span><span class="p">)</span><span class="w">
</span><span class="p">}</span><span class="w">
</span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## Loading required package: mxnet
</code></pre></div><div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">require</span><span class="p">(</span><span class="n">mxnet</span><span class="p">)</span><span class="w">
</span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## Loading required datasets
</code></pre></div><div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">data</span><span class="p">(</span><span class="n">Sonar</span><span class="p">,</span><span class="w"> </span><span class="n">package</span><span class="o">=</span><span class="s2">"mlbench"</span><span class="p">)</span><span class="w">
</span><span class="n">Sonar</span><span class="p">[,</span><span class="m">61</span><span class="p">]</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">as.numeric</span><span class="p">(</span><span class="n">Sonar</span><span class="p">[,</span><span class="m">61</span><span class="p">])</span><span class="m">-1</span><span class="w">
</span><span class="n">train.ind</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="m">1</span><span class="o">:</span><span class="m">50</span><span class="p">,</span><span class="w"> </span><span class="m">100</span><span class="o">:</span><span class="m">150</span><span class="p">)</span><span class="w">
</span><span class="n">train.x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">data.matrix</span><span class="p">(</span><span class="n">Sonar</span><span class="p">[</span><span class="n">train.ind</span><span class="p">,</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="m">60</span><span class="p">])</span><span class="w">
</span><span class="n">train.y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">Sonar</span><span class="p">[</span><span class="n">train.ind</span><span class="p">,</span><span class="w"> </span><span class="m">61</span><span class="p">]</span><span class="w">
</span><span class="n">test.x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">data.matrix</span><span class="p">(</span><span class="n">Sonar</span><span class="p">[</span><span class="o">-</span><span class="n">train.ind</span><span class="p">,</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="m">60</span><span class="p">])</span><span class="w">
</span><span class="n">test.y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">Sonar</span><span class="p">[</span><span class="o">-</span><span class="n">train.ind</span><span class="p">,</span><span class="w"> </span><span class="m">61</span><span class="p">]</span><span class="w">
</span></code></pre></div>
<p>We are going to use a multi-layer perceptron as our classifier. In <code>mxnet</code>, we have a function called <code>mx.mlp</code> for building a general multi-layer neural network to do classification or regression.</p>
<p><code>mx.mlp</code> requires the following parameters:</p>
<ul>
<li>Training data and label</li>
<li>Number of hidden nodes in each hidden layer</li>
<li>Number of nodes in the output layer</li>
<li>Type of the activation</li>
<li>Type of the output loss</li>
<li>The device to train (GPU or CPU)</li>
<li>Other parameters for <code>mx.model.FeedForward.create</code></li>
</ul>
<p>The following code shows an example usage of <code>mx.mlp</code>:</p>
<div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">mx.set.seed</span><span class="p">(</span><span class="m">0</span><span class="p">)</span><span class="w">
</span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.mlp</span><span class="p">(</span><span class="n">train.x</span><span class="p">,</span><span class="w"> </span><span class="n">train.y</span><span class="p">,</span><span class="w"> </span><span class="n">hidden_node</span><span class="o">=</span><span class="m">10</span><span class="p">,</span><span class="w"> </span><span class="n">out_node</span><span class="o">=</span><span class="m">2</span><span class="p">,</span><span class="w"> </span><span class="n">out_activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">,</span><span class="w">
</span><span class="n">num.round</span><span class="o">=</span><span class="m">20</span><span class="p">,</span><span class="w"> </span><span class="n">array.batch.size</span><span class="o">=</span><span class="m">15</span><span class="p">,</span><span class="w"> </span><span class="n">learning.rate</span><span class="o">=</span><span class="m">0.07</span><span class="p">,</span><span class="w"> </span><span class="n">momentum</span><span class="o">=</span><span class="m">0.9</span><span class="p">,</span><span class="w">
</span><span class="n">eval.metric</span><span class="o">=</span><span class="n">mx.metric.accuracy</span><span class="p">)</span><span class="w">
</span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## Auto detect layout of input matrix, use rowmajor..
## Start training with 1 devices
## [1] Train-accuracy=0.488888888888889
## [2] Train-accuracy=0.514285714285714
## [3] Train-accuracy=0.514285714285714
## [4] Train-accuracy=0.514285714285714
## [5] Train-accuracy=0.514285714285714
## [6] Train-accuracy=0.523809523809524
## [7] Train-accuracy=0.619047619047619
## [8] Train-accuracy=0.695238095238095
## [9] Train-accuracy=0.695238095238095
## [10] Train-accuracy=0.761904761904762
## [11] Train-accuracy=0.828571428571429
## [12] Train-accuracy=0.771428571428571
## [13] Train-accuracy=0.742857142857143
## [14] Train-accuracy=0.733333333333333
## [15] Train-accuracy=0.771428571428571
## [16] Train-accuracy=0.847619047619048
## [17] Train-accuracy=0.857142857142857
## [18] Train-accuracy=0.838095238095238
## [19] Train-accuracy=0.838095238095238
## [20] Train-accuracy=0.838095238095238
</code></pre></div>
<p>Note that <code>mx.set.seed</code> controls the random process in <code>mxnet</code>. You can see the accuracy in each round during training. It&#39;s also easy to make predictions and evaluate.</p>
<p>To get an idea of what is happening, view the computation graph from R:</p>
<div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">graph.viz</span><span class="p">(</span><span class="n">model</span><span class="o">$</span><span class="n">symbol</span><span class="p">)</span><span class="w">
</span></code></pre></div>
<p><a href="https://github.com/dmlc/mxnet"><img src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/knitr/graph.computation.png"></a></p>
<div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">preds</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test.x</span><span class="p">)</span><span class="w">
</span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## Auto detect layout of input matrix, use rowmajor.
</code></pre></div><div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">pred.label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">max.col</span><span class="p">(</span><span class="n">t</span><span class="p">(</span><span class="n">preds</span><span class="p">))</span><span class="m">-1</span><span class="w">
</span><span class="n">table</span><span class="p">(</span><span class="n">pred.label</span><span class="p">,</span><span class="w"> </span><span class="n">test.y</span><span class="p">)</span><span class="w">
</span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## test.y
## pred.label 0 1
## 0 24 14
## 1 36 33
</code></pre></div>
<p>Note that for multi-class predictions, mxnet outputs an <code>nclass</code> x <code>nexamples</code> matrix, with each row corresponding to the probability of one class.</p>
<h2 id="regression">Regression</h2>
<p>Again, let us preprocess the data:</p>
<div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">data</span><span class="p">(</span><span class="n">BostonHousing</span><span class="p">,</span><span class="w"> </span><span class="n">package</span><span class="o">=</span><span class="s2">"mlbench"</span><span class="p">)</span><span class="w">
</span><span class="n">train.ind</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">seq</span><span class="p">(</span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="m">506</span><span class="p">,</span><span class="w"> </span><span class="m">3</span><span class="p">)</span><span class="w">
</span><span class="n">train.x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">data.matrix</span><span class="p">(</span><span class="n">BostonHousing</span><span class="p">[</span><span class="n">train.ind</span><span class="p">,</span><span class="w"> </span><span class="m">-14</span><span class="p">])</span><span class="w">
</span><span class="n">train.y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">BostonHousing</span><span class="p">[</span><span class="n">train.ind</span><span class="p">,</span><span class="w"> </span><span class="m">14</span><span class="p">]</span><span class="w">
</span><span class="n">test.x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">data.matrix</span><span class="p">(</span><span class="n">BostonHousing</span><span class="p">[</span><span class="o">-</span><span class="n">train.ind</span><span class="p">,</span><span class="w"> </span><span class="m">-14</span><span class="p">])</span><span class="w">
</span><span class="n">test.y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">BostonHousing</span><span class="p">[</span><span class="o">-</span><span class="n">train.ind</span><span class="p">,</span><span class="w"> </span><span class="m">14</span><span class="p">]</span><span class="w">
</span></code></pre></div>
<p>Although we can use <code>mx.mlp</code> again to do regression by changing the <code>out_activation</code>, this time we are going to introduce a flexible way to configure neural networks in <code>mxnet</code>. Configuration is done by the &quot;Symbol&quot; system in <code>mxnet</code>. The Symbol system takes care of the links among nodes, activation, dropout ratio, etc. Configure a multi-layer neural network as follows:</p>
<div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="c1"># Define the input data</span><span class="w">
</span><span class="n">data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.symbol.Variable</span><span class="p">(</span><span class="s2">"data"</span><span class="p">)</span><span class="w">
</span><span class="c1"># A fully connected hidden layer</span><span class="w">
</span><span class="c1"># data: input source</span><span class="w">
</span><span class="c1"># num_hidden: number of neurons in this hidden layer</span><span class="w">
</span><span class="n">fc1</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.symbol.FullyConnected</span><span class="p">(</span><span class="n">data</span><span class="p">,</span><span class="w"> </span><span class="n">num_hidden</span><span class="o">=</span><span class="m">1</span><span class="p">)</span><span class="w">
</span><span class="c1"># Use linear regression for the output layer</span><span class="w">
</span><span class="n">lro</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.symbol.LinearRegressionOutput</span><span class="p">(</span><span class="n">fc1</span><span class="p">)</span><span class="w">
</span></code></pre></div>
<p>What matters for a regression task is mainly the last function. It enables the new network to optimize for squared loss. Now let&#39;s train on this simple data set. In this configuration, we dropped the hidden layer so that the input layer is directly connected to the output layer.</p>
<p>Next, train a model with this structure and other parameters using <code>mx.model.FeedForward.create</code>:</p>
<div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">mx.set.seed</span><span class="p">(</span><span class="m">0</span><span class="p">)</span><span class="w">
</span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.model.FeedForward.create</span><span class="p">(</span><span class="n">lro</span><span class="p">,</span><span class="w"> </span><span class="n">X</span><span class="o">=</span><span class="n">train.x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="o">=</span><span class="n">train.y</span><span class="p">,</span><span class="w">
</span><span class="n">ctx</span><span class="o">=</span><span class="n">mx.cpu</span><span class="p">(),</span><span class="w"> </span><span class="n">num.round</span><span class="o">=</span><span class="m">50</span><span class="p">,</span><span class="w"> </span><span class="n">array.batch.size</span><span class="o">=</span><span class="m">20</span><span class="p">,</span><span class="w">
</span><span class="n">learning.rate</span><span class="o">=</span><span class="m">2e-6</span><span class="p">,</span><span class="w"> </span><span class="n">momentum</span><span class="o">=</span><span class="m">0.9</span><span class="p">,</span><span class="w"> </span><span class="n">eval.metric</span><span class="o">=</span><span class="n">mx.metric.rmse</span><span class="p">)</span><span class="w">
</span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## Auto detect layout of input matrix, use rowmajor.
## Start training with 1 devices
## [1] Train-rmse=16.063282524034
## [2] Train-rmse=12.2792375712573
## [3] Train-rmse=11.1984634005885
## [4] Train-rmse=10.2645236892904
## [5] Train-rmse=9.49711005504284
## [6] Train-rmse=9.07733734175182
## [7] Train-rmse=9.07884450847991
## [8] Train-rmse=9.10463850277417
## [9] Train-rmse=9.03977049028532
## [10] Train-rmse=8.96870685004475
## [11] Train-rmse=8.93113287361574
## [12] Train-rmse=8.89937257821847
## [13] Train-rmse=8.87182096922953
## [14] Train-rmse=8.84476075083586
## [15] Train-rmse=8.81464673014974
## [16] Train-rmse=8.78672567900196
## [17] Train-rmse=8.76265872846474
## [18] Train-rmse=8.73946101419974
## [19] Train-rmse=8.71651926303267
## [20] Train-rmse=8.69457600919277
## [21] Train-rmse=8.67354928674563
## [22] Train-rmse=8.65328755392436
## [23] Train-rmse=8.63378039680078
## [24] Train-rmse=8.61488162586984
## [25] Train-rmse=8.5965105183022
## [26] Train-rmse=8.57868133563275
## [27] Train-rmse=8.56135851937663
## [28] Train-rmse=8.5444819772098
## [29] Train-rmse=8.52802114610432
## [30] Train-rmse=8.5119504512622
## [31] Train-rmse=8.49624261719241
## [32] Train-rmse=8.48087453238701
## [33] Train-rmse=8.46582689119887
## [34] Train-rmse=8.45107881002491
## [35] Train-rmse=8.43661331401712
## [36] Train-rmse=8.42241575909639
## [37] Train-rmse=8.40847217331365
## [38] Train-rmse=8.39476931796395
## [39] Train-rmse=8.38129658373974
## [40] Train-rmse=8.36804269059018
## [41] Train-rmse=8.35499817678397
## [42] Train-rmse=8.34215505742154
## [43] Train-rmse=8.32950441908131
## [44] Train-rmse=8.31703985777311
## [45] Train-rmse=8.30475363906755
## [46] Train-rmse=8.29264031506106
## [47] Train-rmse=8.28069372820073
## [48] Train-rmse=8.26890902770415
## [49] Train-rmse=8.25728089053853
## [50] Train-rmse=8.24580511500735
</code></pre></div>
<p>It&#39;s also easy to make a prediction and evaluate it:</p>
<div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">preds</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">test.x</span><span class="p">)</span><span class="w">
</span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## Auto detect layout of input matrix, use rowmajor..
</code></pre></div><div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="nf">sqrt</span><span class="p">(</span><span class="n">mean</span><span class="p">((</span><span class="n">preds</span><span class="o">-</span><span class="n">test.y</span><span class="p">)</span><span class="o">^</span><span class="m">2</span><span class="p">))</span><span class="w">
</span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## [1] 7.800502
</code></pre></div>
<p>Currently, we have four predefined metrics: &quot;accuracy&quot;, &quot;rmse&quot;, &quot;mae&quot;, and &quot;rmsle&quot;. MXNet also provides an interface for defining your own metrics:</p>
<div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">demo.metric.mae</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.metric.custom</span><span class="p">(</span><span class="s2">"mae"</span><span class="p">,</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">label</span><span class="p">,</span><span class="w"> </span><span class="n">pred</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="n">pred</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.nd.reshape</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0</span><span class="p">)</span><span class="w">
</span><span class="n">res</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.nd.mean</span><span class="p">(</span><span class="n">mx.nd.abs</span><span class="p">(</span><span class="n">label</span><span class="o">-</span><span class="n">pred</span><span class="p">))</span><span class="w">
</span><span class="nf">return</span><span class="p">(</span><span class="n">res</span><span class="p">)</span><span class="w">
</span><span class="p">})</span><span class="w">
</span></code></pre></div>
<p>This is an example of the mean absolute error metric. Simply plug it into the training function:</p>
<div class="highlight"><pre><code class="language-r" data-lang="r"><span class="w"> </span><span class="n">mx.set.seed</span><span class="p">(</span><span class="m">0</span><span class="p">)</span><span class="w">
</span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.model.FeedForward.create</span><span class="p">(</span><span class="n">lro</span><span class="p">,</span><span class="w"> </span><span class="n">X</span><span class="o">=</span><span class="n">train.x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="o">=</span><span class="n">train.y</span><span class="p">,</span><span class="w">
</span><span class="n">ctx</span><span class="o">=</span><span class="n">mx.cpu</span><span class="p">(),</span><span class="w"> </span><span class="n">num.round</span><span class="o">=</span><span class="m">50</span><span class="p">,</span><span class="w"> </span><span class="n">array.batch.size</span><span class="o">=</span><span class="m">20</span><span class="p">,</span><span class="w">
</span><span class="n">learning.rate</span><span class="o">=</span><span class="m">2e-6</span><span class="p">,</span><span class="w"> </span><span class="n">momentum</span><span class="o">=</span><span class="m">0.9</span><span class="p">,</span><span class="w"> </span><span class="n">eval.metric</span><span class="o">=</span><span class="n">demo.metric.mae</span><span class="p">)</span><span class="w">
</span></code></pre></div><div class="highlight"><pre><code class="language-" data-lang=""> ## Auto detect layout of input matrix, use rowmajor.
## Start training with 1 devices
## [1] Train-mae=14.953625731998
## [2] Train-mae=11.4802955521478
## [3] Train-mae=8.50700579749213
## [4] Train-mae=7.30591265360514
## [5] Train-mae=7.38049803839789
## [6] Train-mae=7.36036252975464
## [7] Train-mae=7.06519222259521
## [8] Train-mae=6.9962231847975
## [9] Train-mae=6.96296903822157
## [10] Train-mae=6.9046172036065
## [11] Train-mae=6.87867620256212
## [12] Train-mae=6.85872554779053
## [13] Train-mae=6.81936407089233
## [14] Train-mae=6.79135354359945
## [15] Train-mae=6.77438741260105
## [16] Train-mae=6.75365140702989
## [17] Train-mae=6.73369296391805
## [18] Train-mae=6.71600982877943
## [19] Train-mae=6.69932826360067
## [20] Train-mae=6.6852519777086
## [21] Train-mae=6.67343420452542
## [22] Train-mae=6.66315894656711
## [23] Train-mae=6.65314838621351
## [24] Train-mae=6.64388704299927
## [25] Train-mae=6.63480265935262
## [26] Train-mae=6.62583245171441
## [27] Train-mae=6.61697626113892
## [28] Train-mae=6.60842116673787
## [29] Train-mae=6.60040124257406
## [30] Train-mae=6.59264140658908
## [31] Train-mae=6.58551020092434
## [32] Train-mae=6.57864215638902
## [33] Train-mae=6.57178926467896
## [34] Train-mae=6.56495311525133
## [35] Train-mae=6.55813185373942
## [36] Train-mae=6.5513252152337
## [37] Train-mae=6.54453214009603
## [38] Train-mae=6.53775374094645
## [39] Train-mae=6.53098879920112
## [40] Train-mae=6.52423816257053
## [41] Train-mae=6.51764053768582
## [42] Train-mae=6.51121346155802
## [43] Train-mae=6.5047902001275
## [44] Train-mae=6.49837123023139
## [45] Train-mae=6.49216641320123
## [46] Train-mae=6.48598252402412
## [47] Train-mae=6.4798010720147
## [48] Train-mae=6.47362396452162
## [49] Train-mae=6.46745183732775
## [50] Train-mae=6.46128723356459
</code></pre></div>
<p>Congratulations! You&#39;ve learned the basics of using MXNet in R. To learn how to use MXNet&#39;s advanced features, see the other tutorials.</p>
<h2 id="next-steps">Next Steps</h2>
<ul>
<li><a href="https://mxnet.io/tutorials/r/classifyRealImageWithPretrainedModel.html">Classify Real-World Images with Pre-trained Model</a></li>
<li><a href="https://mxnet.io/tutorials/r/mnistCompetition.html">Handwritten Digits Classification Competition</a></li>
<li><a href="https://mxnet.io/tutorials/r/charRnnModel.html">Character Language Model using RNN</a></li>
</ul>
</div>
</div>
</div>
</div>
</article>
</main><footer class="site-footer h-card">
<div class="wrapper">
<div class="row">
<div class="col-4">
<h4 class="footer-category-title">Resources</h4>
<ul class="contact-list">
<li><a href="/versions/1.8.0/community/contribute#mxnet-dev-communications">Mailing lists</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a href="https://issues.apache.org/jira/projects/MXNET/issues">Jira Tracker</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/labels/Roadmap">GitHub Roadmap</a></li>
<li><a href="https://discuss.mxnet.io">MXNet Discuss forum</a></li>
<li><a href="/versions/1.8.0/community/contribute">Contribute To MXNet</a></li>
</ul>
</div>
<div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/incubator-mxnet"><svg class="svg-icon"><use xlink:href="/versions/1.8.0/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/incubator-mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.8.0/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.8.0/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul>
</div>
<div class="col-4 footer-text">
<p>A flexible and efficient library for deep learning.</p>
</div>
</div>
</div>
</footer>
<footer class="site-footer2">
<div class="wrapper">
<div class="row">
<div class="col-3">
<img src="/versions/1.8.0/assets/img/apache_incubator_logo.png" class="footer-logo col-2" alt="Apache Incubator logo">
</div>
<div class="footer-bottom-warning col-9">
<p>Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <span
style="font-weight:bold">sponsored by the <i>Apache Incubator</i></span>. Incubation is required
of all newly accepted projects until a further review indicates that the infrastructure,
communications, and decision making process have stabilized in a manner consistent with other
successful ASF projects. While incubation status is not necessarily a reflection of the completeness
or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p><p>"Copyright © 2017-2018, The Apache Software Foundation. Apache MXNet, MXNet, Apache, the Apache
feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the
Apache Software Foundation."</p>
</div>
</div>
</div>
</footer>
</body>
</html>