blob: d120bef8b9ad5a6dda3a0a56f79c4058898e2a0e [file] [log] [blame]
<!DOCTYPE html>
<!---
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
<html lang="en"><head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="/versions/master/assets/img/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 -->
<title>Custom Iterator Tutorial | Apache MXNet</title>
<meta name="generator" content="Jekyll v4.0.0" />
<meta property="og:title" content="Custom Iterator Tutorial" />
<meta property="og:locale" content="en_US" />
<meta name="description" content="A flexible and efficient library for deep learning." />
<meta property="og:description" content="A flexible and efficient library for deep learning." />
<link rel="canonical" href="https://mxnet.apache.org/versions/master/api/r/docs/tutorials/custom_iterator" />
<meta property="og:url" content="https://mxnet.apache.org/versions/master/api/r/docs/tutorials/custom_iterator" />
<meta property="og:site_name" content="Apache MXNet" />
<script type="application/ld+json">
{"url":"https://mxnet.apache.org/versions/master/api/r/docs/tutorials/custom_iterator","headline":"Custom Iterator Tutorial","description":"A flexible and efficient library for deep learning.","@type":"WebPage","@context":"https://schema.org"}</script>
<!-- End Jekyll SEO tag -->
<link rel="stylesheet" href="/versions/master/assets/docsearch.min.css" /><link rel="stylesheet" href="/versions/master/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/master/feed.xml" title="Apache MXNet" /><!-- Matomo -->
<script>
/* Matomo page-view tracking for analytics.apache.org.
 * _paq is used as a command queue; each pushed entry appears to be an array of
 * [trackerMethodName, ...args]. Reusing an existing window._paq (if another
 * script already created one) keeps any commands queued there. */
var _paq = window._paq = window._paq || [];
/* tracker methods like "setCustomDimension" should be called before "trackPageView" */
/* We explicitly disable cookie tracking to avoid privacy issues */
_paq.push(['disableCookies']);
_paq.push(['trackPageView']);
_paq.push(['enableLinkTracking']);
/* Queue the tracker endpoint (u + 'matomo.php') and site id, then load
 * u + 'matomo.js' asynchronously by inserting a new script element in front
 * of the first script element in the document.
 * NOTE(review): presumably matomo.js processes the commands already queued in
 * _paq when it loads, so the push order above matters; confirm against the
 * Matomo JavaScript tracking client docs before reordering. */
(function() {
var u="https://analytics.apache.org/";
_paq.push(['setTrackerUrl', u+'matomo.php']);
_paq.push(['setSiteId', '23']);
var d=document, g=d.createElement('script'), s=d.getElementsByTagName('script')[0];
g.async=true; g.src=u+'matomo.js'; s.parentNode.insertBefore(g,s);
})();
</script>
<!-- End Matomo Code -->
<script src="/versions/master/assets/js/jquery-3.3.1.min.js"></script>
<script src="/versions/master/assets/js/docsearch.min.js"></script><script src="/versions/master/assets/js/globalSearch.js" defer></script>
<script src="/versions/master/assets/js/clipboard.js" defer></script>
<script src="/versions/master/assets/js/copycode.js" defer></script></head>
<body><header class="site-header" role="banner">
<script>
$(function () {
  // HEADER OPACITY LOGIC
  // The header background's alpha starts at 0.4 at the top of the page and
  // rises by 1/300 for every scrolled pixel (this code does not cap it).
  function opacity_header() {
    var alpha = $(window).scrollTop() / 300 + 0.4;
    $('.site-header').css("background-color", "rgba(4,140,204," + alpha + ")");
  }
  $(window).on("scroll", opacity_header);
  opacity_header();

  // MENU SELECTOR LOGIC
  // Mark every top-level nav link whose absolute href occurs inside the
  // current page URL, so the active section is highlighted.
  var currentUrl = window.location.href;
  $('.page-link')
    .filter(function () {
      return currentUrl.includes(this.href);
    })
    .addClass("page-current");
});
</script>
<div class="wrapper">
<a class="site-title" rel="author" href="/versions/master/"><img
src="/versions/master/assets/img/mxnet_logo.png" class="site-header-logo" alt="Apache MXNet"></a>
<nav class="site-nav">
<input type="checkbox" id="nav-trigger" class="nav-trigger"/>
<label for="nav-trigger">
<span class="menu-icon">
<svg viewBox="0 0 18 15" width="18px" height="15px">
<path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
</svg>
</span>
</label>
<div class="gs-search-border">
<div id="gs-search-icon"></div>
<form id="global-search-form">
<input id="global-search" type="text" title="Search" placeholder="Search" />
<div id="global-search-dropdown-container">
<button class="gs-current-version btn" type="button" data-toggle="dropdown">
<span id="gs-current-version-label">master</span>
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown">
<li class="gs-opt gs-versions active">master</li>
<li class="gs-opt gs-versions">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
<span id="global-search-close">x</span>
</form>
</div>
<div class="trigger">
<div id="global-search-mobile-border">
<div id="gs-search-icon-mobile"></div>
<input id="global-search-mobile" placeholder="Search..." type="text" aria-label="Search"/>
<div id="global-search-dropdown-container-mobile">
<button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown" aria-label="Select documentation version">
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown-mobile">
<li class="gs-opt gs-versions active">master</li>
<li class="gs-opt gs-versions">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
</div>
<a class="page-link" href="/versions/master/get_started">Get Started</a>
<a class="page-link" href="/versions/master/features">Features</a>
<a class="page-link" href="/versions/master/ecosystem">Ecosystem</a>
<a class="page-link" href="/versions/master/api">Docs & Tutorials</a>
<a class="page-link" href="/versions/master/trusted_by">Trusted By</a>
<a class="page-link" href="https://github.com/apache/mxnet">GitHub</a>
<div class="dropdown" style="min-width:100px">
<span class="dropdown-header">Apache
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content" style="min-width:250px">
<a href="https://www.apache.org/foundation/">Apache Software Foundation</a>
<a href="https://www.apache.org/licenses/">License</a>
<a href="/versions/master/api/faq/security.html">Security</a>
<a href="https://privacy.apache.org/policies/privacy-policy-public.html">Privacy</a>
<a href="https://www.apache.org/events/current-event">Events</a>
<a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a>
<a href="https://www.apache.org/foundation/thanks.html">Thanks</a>
</div>
</div>
<div class="dropdown">
<span class="dropdown-header">master
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content">
<a class="dropdown-option-active" href="/">master</a>
<a href="/versions/1.9.1/">1.9.1</a>
<a href="/versions/1.8.0/">1.8.0</a>
<a href="/versions/1.7.0/">1.7.0</a>
<a href="/versions/1.6.0/">1.6.0</a>
<a href="/versions/1.5.0/">1.5.0</a>
<a href="/versions/1.4.1/">1.4.1</a>
<a href="/versions/1.3.1/">1.3.1</a>
<a href="/versions/1.2.1/">1.2.1</a>
<a href="/versions/1.1.0/">1.1.0</a>
<a href="/versions/1.0.0/">1.0.0</a>
<a href="/versions/0.12.1/">0.12.1</a>
<a href="/versions/0.11.0/">0.11.0</a>
</div>
</div>
</div>
</nav>
</div>
</header>
<main class="page-content" aria-label="Content">
<script>
</script>
<article class="post">
<header class="post-header wrapper">
<h1 class="post-title">Custom Iterator Tutorial</h1>
<h3></h3></header>
<div class="post-content">
<div class="wrapper">
<div class="row">
<div class="col-3 docs-side-bar">
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page -->
</div>
<div class="col-9">
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
<h1 id="custom-iterator-tutorial">Custom Iterator Tutorial</h1>
<p>This tutorial provides a guideline on how to use and write custom iterators, which can be very useful when working with a dataset that does not fit into memory.</p>
<h2 id="getting-the-data">Getting the data</h2>
<p>The data we are going to use is the <a href="https://yann.lecun.com/exdb/mnist/">MNIST dataset</a> in CSV format, which can be found on <a href="https://pjreddie.com/projects/mnist-in-csv/">this website</a>.</p>
<p>To download the data:</p>
<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>wget http://pjreddie.com/media/files/mnist_train.csv
wget http://pjreddie.com/media/files/mnist_test.csv
</code></pre></div></div>
<p>You’ll get two files: <code class="highlighter-rouge">mnist_train.csv</code>, which contains 60,000 examples of handwritten digits, and <code class="highlighter-rouge">mnist_test.csv</code>, which contains 10,000 examples. The first element of each line in the CSV is the label, a number between 0 and 9. The rest of the line consists of 784 numbers between 0 and 255, the grey levels of a 28x28 pixel matrix. Therefore, each line contains a 28x28-pixel image of a handwritten digit together with its true label.</p>
<h2 id="custom-csv-iterator">Custom CSV Iterator</h2>
<p>Next we are going to create a custom CSV Iterator based on the <a href="https://github.com/apache/mxnet/blob/master/src/io/iter_csv.cc">C++ CSVIterator class</a>.</p>
<p>For that we are going to use the R function <code class="highlighter-rouge">mx.io.CSVIter</code> as a base class. This class takes the parameters <code class="highlighter-rouge">data.csv, data.shape, batch.size</code> and has two main functions: <code class="highlighter-rouge">iter.next()</code>, which advances the iterator to the next batch of data, and <code class="highlighter-rouge">value()</code>, which returns the training data and the label.</p>
<p>The R custom iterator needs to inherit from the C++ data iterator class; for that we use the class <code class="highlighter-rouge">Rcpp_MXArrayDataIter</code>, which is exposed to R through Rcpp. It also needs to have the same parameters: <code class="highlighter-rouge">data.csv, data.shape, batch.size</code>. Apart from that, we also add the field <code class="highlighter-rouge">iter</code>, which holds the CSV iterator that we are going to extend.</p>
<div class="language-r highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">CustomCSVIter</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">setRefClass</span><span class="p">(</span><span class="s2">"CustomCSVIter"</span><span class="p">,</span><span class="w">
</span><span class="n">fields</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="s2">"iter"</span><span class="p">,</span><span class="w"> </span><span class="s2">"data.csv"</span><span class="p">,</span><span class="w"> </span><span class="s2">"data.shape"</span><span class="p">,</span><span class="w"> </span><span class="s2">"batch.size"</span><span class="p">),</span><span class="w">
</span><span class="n">contains</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"Rcpp_MXArrayDataIter"</span><span class="p">,</span><span class="w">
</span><span class="c1">#...</span><span class="w">
</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<p>The next step is to initialize the class. For that we call the base <code class="highlighter-rouge">mx.io.CSVIter</code> and fill in the rest of the fields.</p>
<div class="language-r highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">CustomCSVIter</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">setRefClass</span><span class="p">(</span><span class="s2">"CustomCSVIter"</span><span class="p">,</span><span class="w">
</span><span class="n">fields</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="s2">"iter"</span><span class="p">,</span><span class="w"> </span><span class="s2">"data.csv"</span><span class="p">,</span><span class="w"> </span><span class="s2">"data.shape"</span><span class="p">,</span><span class="w"> </span><span class="s2">"batch.size"</span><span class="p">),</span><span class="w">
</span><span class="n">contains</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"Rcpp_MXArrayDataIter"</span><span class="p">,</span><span class="w">
</span><span class="n">methods</span><span class="o">=</span><span class="nf">list</span><span class="p">(</span><span class="w">
</span><span class="n">initialize</span><span class="o">=</span><span class="k">function</span><span class="p">(</span><span class="n">iter</span><span class="p">,</span><span class="w"> </span><span class="n">data.csv</span><span class="p">,</span><span class="w"> </span><span class="n">data.shape</span><span class="p">,</span><span class="w"> </span><span class="n">batch.size</span><span class="p">){</span><span class="w">
</span><span class="n">feature_len</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data.shape</span><span class="o">*</span><span class="n">data.shape</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="m">1</span><span class="w">
</span><span class="n">csv_iter</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.io.CSVIter</span><span class="p">(</span><span class="n">data.csv</span><span class="o">=</span><span class="n">data.csv</span><span class="p">,</span><span class="w"> </span><span class="n">data.shape</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="n">feature_len</span><span class="p">),</span><span class="w"> </span><span class="n">batch.size</span><span class="o">=</span><span class="n">batch.size</span><span class="p">)</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">iter</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">csv_iter</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">data.csv</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data.csv</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">data.shape</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data.shape</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">batch.size</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">batch.size</span><span class="w">
</span><span class="n">.self</span><span class="w">
</span><span class="p">},</span><span class="w">
</span><span class="c1">#...</span><span class="w">
</span><span class="p">)</span><span class="w">
</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<p>So far there is no difference between the original class and the custom class. Let’s implement the function <code class="highlighter-rouge">value()</code>. Here we transform each example, which the original class returns as an array of 785 numbers, into a 28x28 matrix and a label. We also normalize the training data to lie between 0 and 1 by dividing the pixel values by 255.</p>
<div class="language-r highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">CustomCSVIter</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">setRefClass</span><span class="p">(</span><span class="s2">"CustomCSVIter"</span><span class="p">,</span><span class="w">
</span><span class="n">fields</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="s2">"iter"</span><span class="p">,</span><span class="w"> </span><span class="s2">"data.csv"</span><span class="p">,</span><span class="w"> </span><span class="s2">"data.shape"</span><span class="p">,</span><span class="w"> </span><span class="s2">"batch.size"</span><span class="p">),</span><span class="w">
</span><span class="n">contains</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"Rcpp_MXArrayDataIter"</span><span class="p">,</span><span class="w">
</span><span class="n">methods</span><span class="o">=</span><span class="nf">list</span><span class="p">(</span><span class="w">
</span><span class="n">initialize</span><span class="o">=</span><span class="k">function</span><span class="p">(</span><span class="n">iter</span><span class="p">,</span><span class="w"> </span><span class="n">data.csv</span><span class="p">,</span><span class="w"> </span><span class="n">data.shape</span><span class="p">,</span><span class="w"> </span><span class="n">batch.size</span><span class="p">){</span><span class="w">
</span><span class="n">feature_len</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data.shape</span><span class="o">*</span><span class="n">data.shape</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="m">1</span><span class="w">
</span><span class="n">csv_iter</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.io.CSVIter</span><span class="p">(</span><span class="n">data.csv</span><span class="o">=</span><span class="n">data.csv</span><span class="p">,</span><span class="w"> </span><span class="n">data.shape</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="n">feature_len</span><span class="p">),</span><span class="w"> </span><span class="n">batch.size</span><span class="o">=</span><span class="n">batch.size</span><span class="p">)</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">iter</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">csv_iter</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">data.csv</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data.csv</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">data.shape</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data.shape</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">batch.size</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">batch.size</span><span class="w">
</span><span class="n">.self</span><span class="w">
</span><span class="p">},</span><span class="w">
</span><span class="n">value</span><span class="o">=</span><span class="k">function</span><span class="p">(){</span><span class="w">
</span><span class="n">val</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">as.array</span><span class="p">(</span><span class="n">.self</span><span class="o">$</span><span class="n">iter</span><span class="o">$</span><span class="n">value</span><span class="p">()</span><span class="o">$</span><span class="n">data</span><span class="p">)</span><span class="w">
</span><span class="n">val.x</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">val</span><span class="p">[</span><span class="m">-1</span><span class="p">,]</span><span class="w">
</span><span class="n">val.y</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">val</span><span class="p">[</span><span class="m">1</span><span class="p">,]</span><span class="w">
</span><span class="n">val.x</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">val.x</span><span class="o">/</span><span class="m">255</span><span class="w">
</span><span class="nf">dim</span><span class="p">(</span><span class="n">val.x</span><span class="p">)</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">data.shape</span><span class="p">,</span><span class="w"> </span><span class="n">data.shape</span><span class="p">,</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">ncol</span><span class="p">(</span><span class="n">val.x</span><span class="p">))</span><span class="w">
</span><span class="n">val.x</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.nd.array</span><span class="p">(</span><span class="n">val.x</span><span class="p">)</span><span class="w">
</span><span class="n">val.y</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.nd.array</span><span class="p">(</span><span class="n">val.y</span><span class="p">)</span><span class="w">
</span><span class="nf">list</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">val.x</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="o">=</span><span class="n">val.y</span><span class="p">)</span><span class="w">
</span><span class="p">},</span><span class="w">
</span><span class="c1">#...</span><span class="w">
</span><span class="p">)</span><span class="w">
</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
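<p>Finally, we add the remaining methods needed for training to work correctly (<code class="highlighter-rouge">iter.next()</code>, <code class="highlighter-rouge">reset()</code>, <code class="highlighter-rouge">num.pad()</code> and <code class="highlighter-rouge">finalize()</code>), each of which delegates to the wrapped CSV iterator. The final <code class="highlighter-rouge">CustomCSVIter</code> looks like this:</p>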
<div class="language-r highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">CustomCSVIter</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">setRefClass</span><span class="p">(</span><span class="s2">"CustomCSVIter"</span><span class="p">,</span><span class="w">
</span><span class="n">fields</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="s2">"iter"</span><span class="p">,</span><span class="w"> </span><span class="s2">"data.csv"</span><span class="p">,</span><span class="w"> </span><span class="s2">"data.shape"</span><span class="p">,</span><span class="w"> </span><span class="s2">"batch.size"</span><span class="p">),</span><span class="w">
</span><span class="n">contains</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"Rcpp_MXArrayDataIter"</span><span class="p">,</span><span class="w">
</span><span class="n">methods</span><span class="o">=</span><span class="nf">list</span><span class="p">(</span><span class="w">
</span><span class="n">initialize</span><span class="o">=</span><span class="k">function</span><span class="p">(</span><span class="n">iter</span><span class="p">,</span><span class="w"> </span><span class="n">data.csv</span><span class="p">,</span><span class="w"> </span><span class="n">data.shape</span><span class="p">,</span><span class="w"> </span><span class="n">batch.size</span><span class="p">){</span><span class="w">
</span><span class="n">feature_len</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data.shape</span><span class="o">*</span><span class="n">data.shape</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="m">1</span><span class="w">
</span><span class="n">csv_iter</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.io.CSVIter</span><span class="p">(</span><span class="n">data.csv</span><span class="o">=</span><span class="n">data.csv</span><span class="p">,</span><span class="w"> </span><span class="n">data.shape</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="n">feature_len</span><span class="p">),</span><span class="w"> </span><span class="n">batch.size</span><span class="o">=</span><span class="n">batch.size</span><span class="p">)</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">iter</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">csv_iter</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">data.csv</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data.csv</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">data.shape</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data.shape</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">batch.size</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">batch.size</span><span class="w">
</span><span class="n">.self</span><span class="w">
</span><span class="p">},</span><span class="w">
</span><span class="n">value</span><span class="o">=</span><span class="k">function</span><span class="p">(){</span><span class="w">
</span><span class="n">val</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">as.array</span><span class="p">(</span><span class="n">.self</span><span class="o">$</span><span class="n">iter</span><span class="o">$</span><span class="n">value</span><span class="p">()</span><span class="o">$</span><span class="n">data</span><span class="p">)</span><span class="w">
</span><span class="n">val.x</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">val</span><span class="p">[</span><span class="m">-1</span><span class="p">,]</span><span class="w">
</span><span class="n">val.y</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">val</span><span class="p">[</span><span class="m">1</span><span class="p">,]</span><span class="w">
</span><span class="n">val.x</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">val.x</span><span class="o">/</span><span class="m">255</span><span class="w">
</span><span class="nf">dim</span><span class="p">(</span><span class="n">val.x</span><span class="p">)</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">data.shape</span><span class="p">,</span><span class="w"> </span><span class="n">data.shape</span><span class="p">,</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">ncol</span><span class="p">(</span><span class="n">val.x</span><span class="p">))</span><span class="w">
</span><span class="n">val.x</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.nd.array</span><span class="p">(</span><span class="n">val.x</span><span class="p">)</span><span class="w">
</span><span class="n">val.y</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.nd.array</span><span class="p">(</span><span class="n">val.y</span><span class="p">)</span><span class="w">
</span><span class="nf">list</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">val.x</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="o">=</span><span class="n">val.y</span><span class="p">)</span><span class="w">
</span><span class="p">},</span><span class="w">
</span><span class="n">iter.next</span><span class="o">=</span><span class="k">function</span><span class="p">(){</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">iter</span><span class="o">$</span><span class="n">iter.next</span><span class="p">()</span><span class="w">
</span><span class="p">},</span><span class="w">
</span><span class="n">reset</span><span class="o">=</span><span class="k">function</span><span class="p">(){</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">iter</span><span class="o">$</span><span class="n">reset</span><span class="p">()</span><span class="w">
</span><span class="p">},</span><span class="w">
</span><span class="n">num.pad</span><span class="o">=</span><span class="k">function</span><span class="p">(){</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">iter</span><span class="o">$</span><span class="n">num.pad</span><span class="p">()</span><span class="w">
</span><span class="p">},</span><span class="w">
</span><span class="n">finalize</span><span class="o">=</span><span class="k">function</span><span class="p">(){</span><span class="w">
</span><span class="n">.self</span><span class="o">$</span><span class="n">iter</span><span class="o">$</span><span class="n">finalize</span><span class="p">()</span><span class="w">
</span><span class="p">}</span><span class="w">
</span><span class="p">)</span><span class="w">
</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<p>To instantiate the class, we can do:</p>
<div class="language-r highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">batch.size</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="m">100</span><span class="w">
</span><span class="n">train.iter</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">CustomCSVIter</span><span class="o">$</span><span class="n">new</span><span class="p">(</span><span class="n">iter</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">NULL</span><span class="p">,</span><span class="w"> </span><span class="n">data.csv</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"mnist_train.csv"</span><span class="p">,</span><span class="w"> </span><span class="n">data.shape</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">28</span><span class="p">,</span><span class="w"> </span><span class="n">batch.size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">batch.size</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<h2 id="conclusion">Conclusion</h2>
<p>We have shown how to create a custom CSV iterator by extending the class <code class="highlighter-rouge">mx.io.CSVIter</code>. In our class, we iteratively read a batch of data from a CSV file, transform it, and then pass it to the stochastic gradient descent optimization. This way, we can handle CSV files that are larger than the memory of the machine we are using.</p>
<p>Based on this custom iterator, we can also create data loaders that internally transform or expand the data, making it possible to handle files of any size.</p>
</div>
</div>
</div>
</div>
</article>
</main><footer class="site-footer h-card">
<div class="wrapper">
<div class="row">
<div class="col-4">
<h4 class="footer-category-title">Resources</h4>
<ul class="contact-list">
<li><a href="/versions/master/community#stay-connected">Mailing lists</a></li>
<li><a href="/versions/master/community#github-issues">GitHub Issues</a></li>
<li><a href="https://github.com/apache/mxnet/projects">Projects</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a href="https://discuss.mxnet.io">Forum</a></li>
<li><a href="/versions/master/community">Contribute To MXNet</a></li>
</ul>
</div>
<div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/mxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul>
</div>
<div class="col-4 footer-text">
<p>A flexible and efficient library for deep learning.</p>
</div>
</div>
</div>
</footer>
<footer class="site-footer2">
<div class="wrapper">
<div class="row">
<div class="col-3">
<img src="/versions/master/assets/img/asf_logo.svg" class="footer-logo col-2" alt="The Apache Software Foundation logo">
</div>
<div class="footer-bottom-warning col-9">
<p>"Copyright © 2017-2022, The Apache Software Foundation. Licensed under the Apache License, Version 2.0. Apache MXNet, MXNet, Apache, the Apache
feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the
Apache Software Foundation."</p>
</div>
</div>
</div>
</footer>
</body>
</html>