<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<title>Custom Iterator Tutorial — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: ''
};
</script>
<script src="../../_static/jquery-1.11.1.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/searchindex.js"); Search.init();}); </script>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</link></link></head>
<body role="document">
<div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img src="http://data.mxnet.io/theme/mxnet.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="../../get_started/install.html">Install</a>
<a class="main-nav-link" href="../../tutorials/index.html">Tutorials</a>
<a class="main-nav-link" href="../../how_to/index.html">How To</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="../../api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="../../api/scala/index.html">Scala</a></li>
<li><a class="main-nav-link" href="../../api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="../../api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="../../api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="../../api/perl/index.html">Perl</a></li>
</ul>
</span>
<a class="main-nav-link" href="../../architecture/index.html">Architecture</a>
<a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">Versions(master)<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/>v0.10.14</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/versions/0.10/index.html>0.10</a></li><li><a class="main-nav-link" href=http://mxnet.incubator.apache.org/test/versions/master/index.html>master</a></li></ul></span></nav>
<script> function getRootPath(){ return "../../" } </script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu dropdown-menu-right" id="burgerMenu">
<li><a href="../../get_started/install.html">Install</a></li>
<li><a href="../../tutorials/index.html">Tutorials</a></li>
<li><a href="../../how_to/index.html">How To</a></li>
<li class="dropdown-submenu">
<a href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a href="../../api/python/index.html" tabindex="-1">Python</a>
</li>
<li><a href="../../api/scala/index.html" tabindex="-1">Scala</a>
</li>
<li><a href="../../api/r/index.html" tabindex="-1">R</a>
</li>
<li><a href="../../api/julia/index.html" tabindex="-1">Julia</a>
</li>
<li><a href="../../api/c++/index.html" tabindex="-1">C++</a>
</li>
<li><a href="../../api/perl/index.html" tabindex="-1">Perl</a>
</li>
</ul>
</li>
<li><a href="../../architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://github.com/dmlc/mxnet">Github</a></li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">Versions(master)</a><ul class="dropdown-menu"><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/>v0.10.14</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/versions/0.10/index.html>0.10</a></li><li><a tabindex="-1" href=http://mxnet.incubator.apache.org/test/versions/master/index.html>master</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default"/>
</input></form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
</div>
</div>
</div>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../how_to/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../index.html">Tutorials</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="section" id="custom-iterator-tutorial">
<span id="custom-iterator-tutorial"></span><h1>Custom Iterator Tutorial<a class="headerlink" href="#custom-iterator-tutorial" title="Permalink to this headline"></a></h1>
<p>This tutorial shows how to use and write custom iterators, which can be very useful when working with a dataset that does not fit into memory.</p>
<div class="section" id="getting-the-data">
<span id="getting-the-data"></span><h2>Getting the data<a class="headerlink" href="#getting-the-data" title="Permalink to this headline"></a></h2>
<p>The data we are going to use is the <a class="reference external" href="http://yann.lecun.com/exdb/mnist/">MNIST dataset</a> in CSV format, which can be found on <a class="reference external" href="http://pjreddie.com/projects/mnist-in-csv/">this website</a>.</p>
<p>To download the data:</p>
<div class="highlight-bash"><div class="highlight"><pre><span></span>wget http://pjreddie.com/media/files/mnist_train.csv
wget http://pjreddie.com/media/files/mnist_test.csv
</pre></div>
</div>
<p>You’ll get two files: <code class="docutils literal"><span class="pre">mnist_train.csv</span></code>, which contains 60,000 examples of handwritten digits, and <code class="docutils literal"><span class="pre">mnist_test.csv</span></code>, which contains 10,000 examples. The first element of each line in the CSV is the label, a number between 0 and 9. The remaining 784 values on the line, each between 0 and 255, are the grey levels of a 28x28 pixel matrix. Each line therefore contains a 28x28 image of a handwritten digit together with its true label.</p>
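<p>To make this layout concrete, here is a minimal base R sketch that splits one line into its label and image. The helper <code class="docutils literal"><span class="pre">parse_mnist_row</span></code> is hypothetical (it is not part of MXNet), it assumes the 784 pixels are stored row by row as in the original MNIST files, and it uses a synthetic row instead of reading the real CSV.</p>
<div class="highlight-r"><div class="highlight"><pre><span></span># Hypothetical helper (not part of MXNet): split one CSV line into label + image.
parse_mnist_row <- function(row) {
  stopifnot(length(row) == 785)   # 1 label + 28*28 pixels
  label <- row[1]
  # Assumes pixels are stored row by row, so fill the matrix by row.
  image <- matrix(row[-1], nrow = 28, ncol = 28, byrow = TRUE)
  list(label = label, image = image)
}

# Synthetic line standing in for one row of mnist_train.csv:
# label 7 followed by 784 grey levels in 0..255.
row <- c(7, seq(0, 783) %% 256)
# With the real file, one line could be read like this instead:
# row <- unlist(read.csv("mnist_train.csv", header = FALSE, nrows = 1))
parsed <- parse_mnist_row(row)
parsed$label        # the digit
dim(parsed$image)   # the image dimensions
</pre></div>
</div>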
</div>
<div class="section" id="custom-csv-iterator">
<span id="custom-csv-iterator"></span><h2>Custom CSV Iterator<a class="headerlink" href="#custom-csv-iterator" title="Permalink to this headline"></a></h2>
<p>Next, we create a custom CSV iterator based on the <a class="reference external" href="https://github.com/dmlc/mxnet/blob/master/src/io/iter_csv.cc">C++ CSVIterator class</a>.</p>
<p>For that we use the R function <code class="docutils literal"><span class="pre">mx.io.CSVIter</span></code> as a base. It takes the parameters <code class="docutils literal"><span class="pre">data.csv,</span> <span class="pre">data.shape,</span> <span class="pre">batch.size</span></code>, and the iterator it returns provides two main methods: <code class="docutils literal"><span class="pre">iter.next()</span></code>, which moves the iterator to the next batch of data, and <code class="docutils literal"><span class="pre">value()</span></code>, which returns the training data and the label.</p>
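<p>The contract behind these methods can be illustrated without MXNet. The following is a minimal sketch, not MXNet code: <code class="docutils literal"><span class="pre">ToyBatchIter</span></code> is a hypothetical Reference Class that walks over an in-memory matrix (one example per column) and exposes <code class="docutils literal"><span class="pre">iter.next()</span></code>, <code class="docutils literal"><span class="pre">value()</span></code> and <code class="docutils literal"><span class="pre">reset()</span></code> in the same spirit as the iterator returned by <code class="docutils literal"><span class="pre">mx.io.CSVIter</span></code>. For simplicity it stops before an incomplete final batch rather than padding it.</p>
<div class="highlight-r"><div class="highlight"><pre><span></span># Hypothetical toy iterator (not an MXNet class) following the same
# iter.next() / value() / reset() protocol over an in-memory matrix.
ToyBatchIter <- setRefClass("ToyBatchIter",
  fields = list(data = "matrix", batch.size = "numeric", cursor = "numeric"),
  methods = list(
    initialize = function(data, batch.size) {
      .self$data <- data              # one example per column
      .self$batch.size <- batch.size
      .self$cursor <- 0               # index of the last column served
      .self
    },
    iter.next = function() {
      # Move to the next full batch; FALSE once the data is exhausted.
      if (.self$cursor + .self$batch.size > ncol(.self$data)) return(FALSE)
      .self$cursor <- .self$cursor + .self$batch.size
      TRUE
    },
    value = function() {
      # Columns that make up the current batch.
      idx <- seq(.self$cursor - .self$batch.size + 1, .self$cursor)
      list(data = .self$data[, idx, drop = FALSE])
    },
    reset = function() {
      .self$cursor <- 0               # start again from the first batch
    }
  )
)

# Usage: 6 examples with 2 features each, served in batches of 2.
it <- ToyBatchIter$new(data = matrix(1:12, nrow = 2), batch.size = 2)
while (it$iter.next()) print(it$value()$data)
it$reset()
</pre></div>
</div>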
<p>The custom R iterator needs to inherit from the C++ data iterator class; for that we use the class <code class="docutils literal"><span class="pre">Rcpp_MXArrayDataIter</span></code>, which is exposed to R through Rcpp. It also needs the same parameters: <code class="docutils literal"><span class="pre">data.csv,</span> <span class="pre">data.shape,</span> <span class="pre">batch.size</span></code>. In addition, we add the field <code class="docutils literal"><span class="pre">iter</span></code>, which holds the CSV iterator that we are going to extend.</p>
<div class="highlight-r"><div class="highlight"><pre><span></span>CustomCSVIter <span class="o"><-</span> setRefClass<span class="p">(</span><span class="s">"CustomCSVIter"</span><span class="p">,</span>
fields<span class="o">=</span><span class="kt">c</span><span class="p">(</span><span class="s">"iter"</span><span class="p">,</span> <span class="s">"data.csv"</span><span class="p">,</span> <span class="s">"data.shape"</span><span class="p">,</span> <span class="s">"batch.size"</span><span class="p">),</span>
contains <span class="o">=</span> <span class="s">"Rcpp_MXArrayDataIter"</span><span class="p">,</span>
<span class="c1">#...</span>
<span class="p">)</span>
</pre></div>
</div>
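<p>The next step is to write the <code class="docutils literal"><span class="pre">initialize</span></code> method. In it, we create the base iterator with <code class="docutils literal"><span class="pre">mx.io.CSVIter</span></code> and fill in the remaining fields. Because each CSV row holds the label followed by the pixel values, the base iterator is told that every row contains <code class="docutils literal"><span class="pre">feature_len</span></code> = 28x28 + 1 = 785 values.</p>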
<div class="highlight-r"><div class="highlight"><pre><span></span>CustomCSVIter <span class="o"><-</span> setRefClass<span class="p">(</span><span class="s">"CustomCSVIter"</span><span class="p">,</span>
fields<span class="o">=</span><span class="kt">c</span><span class="p">(</span><span class="s">"iter"</span><span class="p">,</span> <span class="s">"data.csv"</span><span class="p">,</span> <span class="s">"data.shape"</span><span class="p">,</span> <span class="s">"batch.size"</span><span class="p">),</span>
contains <span class="o">=</span> <span class="s">"Rcpp_MXArrayDataIter"</span><span class="p">,</span>
methods<span class="o">=</span><span class="kt">list</span><span class="p">(</span>
initialize<span class="o">=</span><span class="kr">function</span><span class="p">(</span>iter<span class="p">,</span> data.csv<span class="p">,</span> data.shape<span class="p">,</span> batch.size<span class="p">){</span>
feature_len <span class="o"><-</span> data.shape<span class="o">*</span>data.shape <span class="o">+</span> <span class="m">1</span>
csv_iter <span class="o"><-</span> mx.io.CSVIter<span class="p">(</span>data.csv<span class="o">=</span>data.csv<span class="p">,</span> data.shape<span class="o">=</span><span class="kt">c</span><span class="p">(</span>feature_len<span class="p">),</span> batch.size<span class="o">=</span>batch.size<span class="p">)</span>
<span class="m">.</span>self<span class="o">$</span>iter <span class="o"><-</span> csv_iter
<span class="m">.</span>self<span class="o">$</span>data.csv <span class="o"><-</span> data.csv
<span class="m">.</span>self<span class="o">$</span>data.shape <span class="o"><-</span> data.shape
<span class="m">.</span>self<span class="o">$</span>batch.size <span class="o"><-</span> batch.size
<span class="m">.</span>self
<span class="p">},</span>
<span class="c1">#...</span>
<span class="p">)</span>
<span class="p">)</span>
</pre></div>
</div>
<p>So far the custom class behaves exactly like the original one. Now let’s implement the method <code class="docutils literal"><span class="pre">value()</span></code>. The base iterator returns each batch as a matrix with one example per column, so every column is an array of 785 numbers whose first entry is the label. We split it into the label and a 28x28 matrix, and normalize the pixel values to lie between 0 and 1 by dividing them by 255.</p>
<div class="highlight-r"><div class="highlight"><pre><span></span>CustomCSVIter <span class="o"><-</span> setRefClass<span class="p">(</span><span class="s">"CustomCSVIter"</span><span class="p">,</span>
fields<span class="o">=</span><span class="kt">c</span><span class="p">(</span><span class="s">"iter"</span><span class="p">,</span> <span class="s">"data.csv"</span><span class="p">,</span> <span class="s">"data.shape"</span><span class="p">,</span> <span class="s">"batch.size"</span><span class="p">),</span>
contains <span class="o">=</span> <span class="s">"Rcpp_MXArrayDataIter"</span><span class="p">,</span>
methods<span class="o">=</span><span class="kt">list</span><span class="p">(</span>
initialize<span class="o">=</span><span class="kr">function</span><span class="p">(</span>iter<span class="p">,</span> data.csv<span class="p">,</span> data.shape<span class="p">,</span> batch.size<span class="p">){</span>
feature_len <span class="o"><-</span> data.shape<span class="o">*</span>data.shape <span class="o">+</span> <span class="m">1</span>
csv_iter <span class="o"><-</span> mx.io.CSVIter<span class="p">(</span>data.csv<span class="o">=</span>data.csv<span class="p">,</span> data.shape<span class="o">=</span><span class="kt">c</span><span class="p">(</span>feature_len<span class="p">),</span> batch.size<span class="o">=</span>batch.size<span class="p">)</span>
<span class="m">.</span>self<span class="o">$</span>iter <span class="o"><-</span> csv_iter
<span class="m">.</span>self<span class="o">$</span>data.csv <span class="o"><-</span> data.csv
<span class="m">.</span>self<span class="o">$</span>data.shape <span class="o"><-</span> data.shape
<span class="m">.</span>self<span class="o">$</span>batch.size <span class="o"><-</span> batch.size
<span class="m">.</span>self
<span class="p">},</span>
value<span class="o">=</span><span class="kr">function</span><span class="p">(){</span>
val <span class="o"><-</span> <span class="kp">as.array</span><span class="p">(</span><span class="m">.</span>self<span class="o">$</span>iter<span class="o">$</span>value<span class="p">()</span><span class="o">$</span>data<span class="p">)</span>
val.x <span class="o"><-</span> val<span class="p">[</span><span class="m">-1</span><span class="p">,]</span>
val.y <span class="o"><-</span> val<span class="p">[</span><span class="m">1</span><span class="p">,]</span>
val.x <span class="o"><-</span> val.x<span class="o">/</span><span class="m">255</span>
<span class="kp">dim</span><span class="p">(</span>val.x<span class="p">)</span> <span class="o"><-</span> <span class="kt">c</span><span class="p">(</span>data.shape<span class="p">,</span> data.shape<span class="p">,</span> <span class="m">1</span><span class="p">,</span> <span class="kp">ncol</span><span class="p">(</span>val.x<span class="p">))</span>
val.x <span class="o"><-</span> mx.nd.array<span class="p">(</span>val.x<span class="p">)</span>
val.y <span class="o"><-</span> mx.nd.array<span class="p">(</span>val.y<span class="p">)</span>
<span class="kt">list</span><span class="p">(</span>data<span class="o">=</span>val.x<span class="p">,</span> label<span class="o">=</span>val.y<span class="p">)</span>
<span class="p">},</span>
<span class="c1">#...</span>
<span class="p">)</span>
<span class="p">)</span>
</pre></div>
</div>
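<p>The reshaping inside <code class="docutils literal"><span class="pre">value()</span></code> can be checked with plain R arrays. The sketch below is illustrative only: the matrix <code class="docutils literal"><span class="pre">val</span></code> is synthetic and stands in for <code class="docutils literal"><span class="pre">as.array(.self$iter$value()$data)</span></code>, and the final conversion with <code class="docutils literal"><span class="pre">mx.nd.array</span></code> is left out so that it runs without MXNet.</p>
<div class="highlight-r"><div class="highlight"><pre><span></span># Synthetic stand-in for one batch from the base CSV iterator:
# 785 rows (label + 28*28 pixels), one column per example.
data.shape <- 28
batch.size <- 3
pixels <- matrix(rep(c(0, 255), length.out = data.shape^2 * batch.size),
                 ncol = batch.size)
val <- rbind(c(3, 1, 4), pixels)   # first row holds the labels

val.x <- val[-1, ]       # drop the label row: 784 x batch.size
val.y <- val[1, ]        # one label per example
val.x <- val.x / 255     # scale grey levels to [0, 1]
# Reshape to data.shape x data.shape x 1 channel x batch, as in value()
dim(val.x) <- c(data.shape, data.shape, 1, ncol(val.x))
</pre></div>
</div>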
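<p>Finally, we add the remaining methods that training needs: <code class="docutils literal"><span class="pre">iter.next()</span></code>, <code class="docutils literal"><span class="pre">reset()</span></code>, <code class="docutils literal"><span class="pre">num.pad()</span></code> and <code class="docutils literal"><span class="pre">finalize()</span></code>, each of which simply delegates to the wrapped CSV iterator. The final <code class="docutils literal"><span class="pre">CustomCSVIter</span></code> looks like this:</p>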
<div class="highlight-r"><div class="highlight"><pre><span></span>CustomCSVIter <span class="o"><-</span> setRefClass<span class="p">(</span><span class="s">"CustomCSVIter"</span><span class="p">,</span>
fields<span class="o">=</span><span class="kt">c</span><span class="p">(</span><span class="s">"iter"</span><span class="p">,</span> <span class="s">"data.csv"</span><span class="p">,</span> <span class="s">"data.shape"</span><span class="p">,</span> <span class="s">"batch.size"</span><span class="p">),</span>
contains <span class="o">=</span> <span class="s">"Rcpp_MXArrayDataIter"</span><span class="p">,</span>
methods<span class="o">=</span><span class="kt">list</span><span class="p">(</span>
initialize<span class="o">=</span><span class="kr">function</span><span class="p">(</span>iter<span class="p">,</span> data.csv<span class="p">,</span> data.shape<span class="p">,</span> batch.size<span class="p">){</span>
feature_len <span class="o"><-</span> data.shape<span class="o">*</span>data.shape <span class="o">+</span> <span class="m">1</span>
csv_iter <span class="o"><-</span> mx.io.CSVIter<span class="p">(</span>data.csv<span class="o">=</span>data.csv<span class="p">,</span> data.shape<span class="o">=</span><span class="kt">c</span><span class="p">(</span>feature_len<span class="p">),</span> batch.size<span class="o">=</span>batch.size<span class="p">)</span>
<span class="m">.</span>self<span class="o">$</span>iter <span class="o"><-</span> csv_iter
<span class="m">.</span>self<span class="o">$</span>data.csv <span class="o"><-</span> data.csv
<span class="m">.</span>self<span class="o">$</span>data.shape <span class="o"><-</span> data.shape
<span class="m">.</span>self<span class="o">$</span>batch.size <span class="o"><-</span> batch.size
<span class="m">.</span>self
<span class="p">},</span>
value<span class="o">=</span><span class="kr">function</span><span class="p">(){</span>
val <span class="o"><-</span> <span class="kp">as.array</span><span class="p">(</span><span class="m">.</span>self<span class="o">$</span>iter<span class="o">$</span>value<span class="p">()</span><span class="o">$</span>data<span class="p">)</span>
val.x <span class="o"><-</span> val<span class="p">[</span><span class="m">-1</span><span class="p">,]</span>
val.y <span class="o"><-</span> val<span class="p">[</span><span class="m">1</span><span class="p">,]</span>
val.x <span class="o"><-</span> val.x<span class="o">/</span><span class="m">255</span>
<span class="kp">dim</span><span class="p">(</span>val.x<span class="p">)</span> <span class="o"><-</span> <span class="kt">c</span><span class="p">(</span>data.shape<span class="p">,</span> data.shape<span class="p">,</span> <span class="m">1</span><span class="p">,</span> <span class="kp">ncol</span><span class="p">(</span>val.x<span class="p">))</span>
val.x <span class="o"><-</span> mx.nd.array<span class="p">(</span>val.x<span class="p">)</span>
val.y <span class="o"><-</span> mx.nd.array<span class="p">(</span>val.y<span class="p">)</span>
<span class="kt">list</span><span class="p">(</span>data<span class="o">=</span>val.x<span class="p">,</span> label<span class="o">=</span>val.y<span class="p">)</span>
<span class="p">},</span>
iter.next<span class="o">=</span><span class="kr">function</span><span class="p">(){</span>
<span class="m">.</span>self<span class="o">$</span>iter<span class="o">$</span>iter.next<span class="p">()</span>
<span class="p">},</span>
reset<span class="o">=</span><span class="kr">function</span><span class="p">(){</span>
<span class="m">.</span>self<span class="o">$</span>iter<span class="o">$</span>reset<span class="p">()</span>
<span class="p">},</span>
num.pad<span class="o">=</span><span class="kr">function</span><span class="p">(){</span>
<span class="m">.</span>self<span class="o">$</span>iter<span class="o">$</span>num.pad<span class="p">()</span>
<span class="p">},</span>
finalize<span class="o">=</span><span class="kr">function</span><span class="p">(){</span>
<span class="m">.</span>self<span class="o">$</span>iter<span class="o">$</span>finalize<span class="p">()</span>
<span class="p">}</span>
<span class="p">)</span>
<span class="p">)</span>
</pre></div>
</div>
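<p>The delegation pattern used here, where the wrapper keeps an inner iterator in a field, forwards <code class="docutils literal"><span class="pre">iter.next()</span></code> and <code class="docutils literal"><span class="pre">reset()</span></code> to it, and only post-processes <code class="docutils literal"><span class="pre">value()</span></code>, does not depend on MXNet itself. Below is a minimal base R sketch of that pattern; <code class="docutils literal"><span class="pre">make_inner</span></code> and <code class="docutils literal"><span class="pre">ScaledIter</span></code> are hypothetical names, and the inner iterator is a simple closure over a list of precomputed batches rather than an <code class="docutils literal"><span class="pre">mx.io.CSVIter</span></code>.</p>
<div class="highlight-r"><div class="highlight"><pre><span></span># Hypothetical inner iterator: a closure over a list of precomputed batches.
make_inner <- function(batches) {
  pos <- 0
  list(
    iter.next = function() { pos <<- pos + 1; pos <= length(batches) },
    value     = function() list(data = batches[[pos]]),
    reset     = function() pos <<- 0
  )
}

# Hypothetical wrapper: delegates navigation, transforms only the values.
ScaledIter <- setRefClass("ScaledIter",
  fields = list(iter = "list"),
  methods = list(
    iter.next = function() .self$iter$iter.next(),
    reset     = function() .self$iter$reset(),
    value     = function() list(data = .self$iter$value()$data / 255)
  )
)

it <- ScaledIter$new(iter = make_inner(list(c(0, 255), c(51, 102))))
while (it$iter.next()) print(it$value()$data)
</pre></div>
</div>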
<p>To create an instance of <code class="docutils literal"><span class="pre">CustomCSVIter</span></code>, we can simply run:</p>
<div class="highlight-r"><div class="highlight"><pre><span></span>batch.size <span class="o"><-</span> <span class="m">100</span>
train.iter <span class="o"><-</span> CustomCSVIter<span class="o">$</span>new<span class="p">(</span>iter <span class="o">=</span> <span class="kc">NULL</span><span class="p">,</span> data.csv <span class="o">=</span> <span class="s">"mnist_train.csv"</span><span class="p">,</span> data.shape <span class="o">=</span> <span class="m">28</span><span class="p">,</span> batch.size <span class="o">=</span> batch.size<span class="p">)</span>
</pre></div>
</div>
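<p>With this batch size, the number of batches per epoch follows directly from the size of the training file. The short computation below assumes the 60,000 training rows described above and an iterator that pads an incomplete last batch up to <code class="docutils literal"><span class="pre">batch.size</span></code>, the number of padded examples being what <code class="docutils literal"><span class="pre">num.pad()</span></code> reports. Since 60,000 is a multiple of 100, each epoch has 600 full batches and no padding, which matches the training log further down, where the speedometer reports up to Batch [600].</p>
<div class="highlight-r"><div class="highlight"><pre><span></span>n.train <- 60000          # rows in mnist_train.csv
batch.size <- 100
n.batches <- ceiling(n.train / batch.size)    # batches per epoch
n.pad <- n.batches * batch.size - n.train     # padded examples in the last batch
c(batches = n.batches, padding = n.pad)
</pre></div>
</div>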
</div>
<div class="section" id="cnn-model">
<span id="cnn-model"></span><h2>CNN Model<a class="headerlink" href="#cnn-model" title="Permalink to this headline"></a></h2>
<p>For this tutorial we use the well-known LeNet architecture:</p>
<div class="highlight-r"><div class="highlight"><pre><span></span>lenet.model <span class="o"><-</span> <span class="kr">function</span><span class="p">(){</span>
data <span class="o"><-</span> mx.symbol.Variable<span class="p">(</span><span class="s">'data'</span><span class="p">)</span>
conv1 <span class="o"><-</span> mx.symbol.Convolution<span class="p">(</span>data<span class="o">=</span>data<span class="p">,</span> kernel<span class="o">=</span><span class="kt">c</span><span class="p">(</span><span class="m">5</span><span class="p">,</span><span class="m">5</span><span class="p">),</span> num_filter<span class="o">=</span><span class="m">20</span><span class="p">)</span> <span class="c1">#first conv</span>
tanh1 <span class="o"><-</span> mx.symbol.Activation<span class="p">(</span>data<span class="o">=</span>conv1<span class="p">,</span> act_type<span class="o">=</span><span class="s">"tanh"</span><span class="p">)</span>
pool1 <span class="o"><-</span> mx.symbol.Pooling<span class="p">(</span>data<span class="o">=</span>tanh1<span class="p">,</span> pool_type<span class="o">=</span><span class="s">"max"</span><span class="p">,</span> kernel<span class="o">=</span><span class="kt">c</span><span class="p">(</span><span class="m">2</span><span class="p">,</span><span class="m">2</span><span class="p">),</span> stride<span class="o">=</span><span class="kt">c</span><span class="p">(</span><span class="m">2</span><span class="p">,</span><span class="m">2</span><span class="p">))</span>
conv2 <span class="o"><-</span> mx.symbol.Convolution<span class="p">(</span>data<span class="o">=</span>pool1<span class="p">,</span> kernel<span class="o">=</span><span class="kt">c</span><span class="p">(</span><span class="m">5</span><span class="p">,</span><span class="m">5</span><span class="p">),</span> num_filter<span class="o">=</span><span class="m">50</span><span class="p">)</span><span class="c1"># second conv</span>
tanh2 <span class="o"><-</span> mx.symbol.Activation<span class="p">(</span>data<span class="o">=</span>conv2<span class="p">,</span> act_type<span class="o">=</span><span class="s">"tanh"</span><span class="p">)</span>
pool2 <span class="o"><-</span> mx.symbol.Pooling<span class="p">(</span>data<span class="o">=</span>tanh2<span class="p">,</span> pool_type<span class="o">=</span><span class="s">"max"</span><span class="p">,</span> kernel<span class="o">=</span><span class="kt">c</span><span class="p">(</span><span class="m">2</span><span class="p">,</span><span class="m">2</span><span class="p">),</span> stride<span class="o">=</span><span class="kt">c</span><span class="p">(</span><span class="m">2</span><span class="p">,</span><span class="m">2</span><span class="p">))</span>
flatten <span class="o"><-</span> mx.symbol.Flatten<span class="p">(</span>data<span class="o">=</span>pool2<span class="p">)</span>
fc1 <span class="o"><-</span> mx.symbol.FullyConnected<span class="p">(</span>data<span class="o">=</span>flatten<span class="p">,</span> num_hidden<span class="o">=</span><span class="m">100</span><span class="p">)</span> <span class="c1"># first fullc</span>
tanh3 <span class="o"><-</span> mx.symbol.Activation<span class="p">(</span>data<span class="o">=</span>fc1<span class="p">,</span> act_type<span class="o">=</span><span class="s">"tanh"</span><span class="p">)</span>
fc2 <span class="o"><-</span> mx.symbol.FullyConnected<span class="p">(</span>data<span class="o">=</span>tanh3<span class="p">,</span> num_hidden<span class="o">=</span><span class="m">10</span><span class="p">)</span> <span class="c1"># second fullc</span>
network <span class="o"><-</span> mx.symbol.SoftmaxOutput<span class="p">(</span>data<span class="o">=</span>fc2<span class="p">)</span> <span class="c1"># loss</span>
network
<span class="p">}</span>
network <span class="o"><-</span> lenet.model<span class="p">()</span>
</pre></div>
</div>
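<p>To see how the 28x28 input flows through this network, the spatial size after each layer can be computed by hand. The sketch below assumes MXNet's defaults for the layers above (stride 1 and no padding for the convolutions, and floor rounding for pooling); the helper <code class="docutils literal"><span class="pre">out_size</span></code> is hypothetical.</p>
<div class="highlight-r"><div class="highlight"><pre><span></span># Output size of a convolution or pooling window (floor rounding).
out_size <- function(n, kernel, stride = 1, pad = 0) {
  (n + 2 * pad - kernel) %/% stride + 1
}

s <- 28                                    # input: 28 x 28 x 1
s <- out_size(s, kernel = 5)               # conv1 (5x5, 20 filters): 24 x 24
s <- out_size(s, kernel = 2, stride = 2)   # pool1: 12 x 12
s <- out_size(s, kernel = 5)               # conv2 (5x5, 50 filters): 8 x 8
s <- out_size(s, kernel = 2, stride = 2)   # pool2: 4 x 4
flat <- 50 * s * s                         # inputs to fc1 after Flatten
</pre></div>
</div>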
</div>
<div class="section" id="training-with-the-custom-iterator">
<span id="training-with-the-custom-iterator"></span><h2>Training with the Custom Iterator<a class="headerlink" href="#training-with-the-custom-iterator" title="Permalink to this headline"></a></h2>
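<p>Finally, we pass the custom iterator directly as the training data source <code class="docutils literal"><span class="pre">X</span></code>. The example trains on the first GPU (<code class="docutils literal"><span class="pre">mx.gpu(0)</span></code>); on a machine without a GPU, use <code class="docutils literal"><span class="pre">ctx=mx.cpu()</span></code> instead.</p>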
<div class="highlight-r"><div class="highlight"><pre><span></span>model <span class="o"><-</span> mx.model.FeedForward.create<span class="p">(</span>symbol<span class="o">=</span>network<span class="p">,</span>
X<span class="o">=</span>train.iter<span class="p">,</span>
ctx<span class="o">=</span>mx.gpu<span class="p">(</span><span class="m">0</span><span class="p">),</span>
num.round<span class="o">=</span><span class="m">10</span><span class="p">,</span>
array.batch.size<span class="o">=</span>batch.size<span class="p">,</span>
learning.rate<span class="o">=</span><span class="m">0.1</span><span class="p">,</span>
momentum<span class="o">=</span><span class="m">0.9</span><span class="p">,</span>
eval.metric<span class="o">=</span>mx.metric.accuracy<span class="p">,</span>
wd<span class="o">=</span><span class="m">0.00001</span><span class="p">,</span>
batch.end.callback<span class="o">=</span>mx.callback.log.speedometer<span class="p">(</span>batch.size<span class="p">,</span> frequency <span class="o">=</span> <span class="m">100</span><span class="p">)</span>
<span class="p">)</span>
</pre></div>
</div>
<p>The log of the last two epochs on a K80 GPU looks like this:</p>
<div class="highlight-bash"><div class="highlight"><pre><span></span><span class="o">[</span><span class="m">8</span><span class="o">]</span> Train-accuracy<span class="o">=</span><span class="m">0</span>.998866666666667
Batch <span class="o">[</span><span class="m">100</span><span class="o">]</span> Speed: <span class="m">15413</span>.0104454713 samples/sec Train-accuracy<span class="o">=</span><span class="m">0</span>.999
Batch <span class="o">[</span><span class="m">200</span><span class="o">]</span> Speed: <span class="m">16629</span>.3412459049 samples/sec Train-accuracy<span class="o">=</span><span class="m">0</span>.99935
Batch <span class="o">[</span><span class="m">300</span><span class="o">]</span> Speed: <span class="m">18412</span>.6900509319 samples/sec Train-accuracy<span class="o">=</span><span class="m">0</span>.9995
Batch <span class="o">[</span><span class="m">400</span><span class="o">]</span> Speed: <span class="m">16757</span>.2882328335 samples/sec Train-accuracy<span class="o">=</span><span class="m">0</span>.999425
Batch <span class="o">[</span><span class="m">500</span><span class="o">]</span> Speed: <span class="m">17116</span>.6529207406 samples/sec Train-accuracy<span class="o">=</span><span class="m">0</span>.99946
Batch <span class="o">[</span><span class="m">600</span><span class="o">]</span> Speed: <span class="m">19627</span>.589505195 samples/sec Train-accuracy<span class="o">=</span><span class="m">0</span>.99945
<span class="o">[</span><span class="m">9</span><span class="o">]</span> Train-accuracy<span class="o">=</span><span class="m">0</span>.9991
Batch <span class="o">[</span><span class="m">100</span><span class="o">]</span> Speed: <span class="m">18971</span>.5745536982 samples/sec Train-accuracy<span class="o">=</span><span class="m">0</span>.9992
Batch <span class="o">[</span><span class="m">200</span><span class="o">]</span> Speed: <span class="m">15554</span>.8822435383 samples/sec Train-accuracy<span class="o">=</span><span class="m">0</span>.99955
Batch <span class="o">[</span><span class="m">300</span><span class="o">]</span> Speed: <span class="m">18327</span>.6950115053 samples/sec Train-accuracy<span class="o">=</span><span class="m">0</span>.9997
Batch <span class="o">[</span><span class="m">400</span><span class="o">]</span> Speed: <span class="m">17103</span>.0705411788 samples/sec Train-accuracy<span class="o">=</span><span class="m">0</span>.9997
Batch <span class="o">[</span><span class="m">500</span><span class="o">]</span> Speed: <span class="m">15104</span>.8656902394 samples/sec Train-accuracy<span class="o">=</span><span class="m">0</span>.99974
Batch <span class="o">[</span><span class="m">600</span><span class="o">]</span> Speed: <span class="m">13818</span>.7899518255 samples/sec Train-accuracy<span class="o">=</span><span class="m">0</span>.99975
<span class="o">[</span><span class="m">10</span><span class="o">]</span> Train-accuracy<span class="o">=</span><span class="m">0</span>.99975
</pre></div>
</div>
</div>
<div class="section" id="conclusion">
<span id="conclusion"></span><h2>Conclusion<a class="headerlink" href="#conclusion" title="Permalink to this headline"></a></h2>
<p>We have shown how to create a custom CSV iterator by extending the class <code class="docutils literal"><span class="pre">mx.io.CSVIter</span></code>. Our class iteratively reads a batch of data from a CSV file, transforms it, and passes it to the stochastic gradient descent optimizer. This way, we can handle CSV files that are larger than the memory of the machine we are using.</p>
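<p>The batch-by-batch reading described above can be illustrated in plain base R, independently of MXNet. The sketch below is not the internals of <code class="docutils literal"><span class="pre">mx.io.CSVIter</span></code>: the helper name <code class="docutils literal"><span class="pre">stream_csv_batches</span></code> and the callback <code class="docutils literal"><span class="pre">process_batch</span></code> are hypothetical, and the file is assumed to have a single header row. Only <code class="docutils literal"><span class="pre">batch.size</span></code> lines are parsed and held in memory at a time, which is what keeps memory use bounded regardless of file size.</p>

```r
# Hypothetical sketch (not the mx.io.CSVIter internals): stream a CSV file
# in fixed-size batches using only base R.
stream_csv_batches <- function(path, batch.size, process_batch, header = TRUE) {
  con <- file(path, open = "r")
  on.exit(close(con))                  # always release the file handle
  if (header) readLines(con, n = 1)    # assumption: exactly one header row
  n.batches <- 0
  repeat {
    # Read at most batch.size raw lines; an open connection resumes
    # where the previous read stopped.
    lines <- readLines(con, n = batch.size)
    if (length(lines) == 0) break      # end of file
    # Parse only this chunk, so memory use is bounded by batch.size.
    batch <- read.csv(text = lines, header = FALSE)
    process_batch(batch)               # e.g. transform, then one optimizer step
    n.batches <- n.batches + 1
  }
  n.batches
}

# Usage: a 10-row demo file read in batches of 4 rows.
tmp <- tempfile(fileext = ".csv")
write.csv(data.frame(label = 0:9, x1 = 1:10, x2 = 11:20), tmp, row.names = FALSE)
sizes <- integer(0)
n <- stream_csv_batches(tmp, batch.size = 4,
                        process_batch = function(b) sizes <<- c(sizes, nrow(b)))
print(n)      # 3
print(sizes)  # 4 4 2
```

<p>The last batch is simply shorter than the others; a real training iterator would typically pad or drop it so that every batch has the same shape.</p>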
<p>Based on this custom iterator, we can also create data loaders that internally transform or expand the data, which makes it possible to handle files of any size.</p>
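<p>A loader that transforms or expands data on the fly might apply a per-batch function such as the hypothetical sketch below. It assumes an MNIST-style CSV layout (a label column followed by 28 by 28 pixel columns with values from 0 to 255); the function name <code class="docutils literal"><span class="pre">transform_and_expand</span></code> and the one-pixel shift used as the "expansion" are illustrative choices, not part of the tutorial's iterator. Each raw batch is rescaled to [0, 1], reshaped to the <code class="docutils literal"><span class="pre">c(28, 28, 1, N)</span></code> array layout used for image input in MXNet's R package, and doubled with shifted copies.</p>

```r
# Hypothetical sketch: transform and expand one raw batch inside a loader.
# Assumes each row is: label, then side * side pixel values in 0..255.
transform_and_expand <- function(batch, side = 28) {
  batch <- as.matrix(batch)
  label <- batch[, 1]
  # Features x samples, rescaled to [0, 1].
  pixels <- t(batch[, -1, drop = FALSE]) / 255
  n <- ncol(pixels)
  # Reshape to side x side x 1 channel x n samples (dim<- drops dimnames).
  dim(pixels) <- c(side, side, 1, n)
  # Expansion: a copy of every image shifted one pixel along the first axis,
  # with the vacated first row filled with zeros.
  shifted <- array(0, dim = dim(pixels))
  shifted[2:side, , , ] <- pixels[1:(side - 1), , , , drop = FALSE]
  # Concatenate along the last (sample) axis: originals first, then shifted.
  data <- array(c(pixels, shifted), dim = c(side, side, 1, 2 * n))
  list(data = data, label = c(label, label))
}

# Usage on a toy batch of three 2 x 2 "images".
b <- cbind(label = c(1, 2, 3),
           matrix(c(255, 0, 0, 0,
                    0, 255, 0, 0,
                    0, 0, 255, 0), nrow = 3, byrow = TRUE))
out <- transform_and_expand(b, side = 2)
print(dim(out$data))  # 2 2 1 6
print(out$label)      # 1 2 3 1 2 3
```

<p>Because the expansion doubles the number of samples per batch, the effective batch size seen by the optimizer doubles too, so settings such as <code class="docutils literal"><span class="pre">batch.size</span></code> or the learning rate may need adjusting.</p>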
</div>
</div>
<div class="container">
<div class="footer">
<p> © 2015-2017 DMLC. All rights reserved. </p>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Custom Iterator Tutorial</a><ul>
<li><a class="reference internal" href="#getting-the-data">Getting the data</a></li>
<li><a class="reference internal" href="#custom-csv-iterator">Custom CSV Iterator</a></li>
<li><a class="reference internal" href="#cnn-model">CNN Model</a></li>
<li><a class="reference internal" href="#training-with-the-custom-iterator">Training with the Custom Iterator</a></li>
<li><a class="reference internal" href="#conclusion">Conclusion</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div> <!-- pagename != index -->
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script type="text/javascript">
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</div></body>
</html>