blob: 0508f81fd8a955d88f7298140190afe578b9893d [file] [log] [blame]
<!DOCTYPE html>
<!---
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
<html lang="en"><head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="/versions/master/assets/img/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 -->
<title>LSTM Time Series | Apache MXNet</title>
<meta name="generator" content="Jekyll v4.0.0" />
<meta property="og:title" content="LSTM Time Series" />
<meta property="og:locale" content="en_US" />
<meta name="description" content="A flexible and efficient library for deep learning." />
<meta property="og:description" content="A flexible and efficient library for deep learning." />
<link rel="canonical" href="https://mxnet.apache.org/versions/master/api/r/docs/tutorials/multi_dim_lstm" />
<meta property="og:url" content="https://mxnet.apache.org/versions/master/api/r/docs/tutorials/multi_dim_lstm" />
<meta property="og:site_name" content="Apache MXNet" />
<script type="application/ld+json">
{"url":"https://mxnet.apache.org/versions/master/api/r/docs/tutorials/multi_dim_lstm","headline":"LSTM Time Series","description":"A flexible and efficient library for deep learning.","@type":"WebPage","@context":"https://schema.org"}</script>
<!-- End Jekyll SEO tag -->
<link rel="stylesheet" href="/versions/master/assets/docsearch.min.css" /><link rel="stylesheet" href="/versions/master/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/master/feed.xml" title="Apache MXNet" /><!-- Matomo -->
<script>
/* Matomo analytics bootstrap: tracker commands are queued on the global
   `_paq` array, then matomo.js is loaded asynchronously from
   analytics.apache.org (presumably it drains the queue once loaded). */
/* Reuse a queue created earlier on the page if one exists, else start a new array. */
var _paq = window._paq = window._paq || [];
/* tracker methods like "setCustomDimension" should be called before "trackPageView" */
/* We explicitly disable cookie tracking to avoid privacy issues */
_paq.push(['disableCookies']);
_paq.push(['trackPageView']);
_paq.push(['enableLinkTracking']);
/* IIFE keeps u/d/g/s out of the global scope. */
(function() {
var u="https://analytics.apache.org/";
/* Configure the collection endpoint and the site id used for this site. */
_paq.push(['setTrackerUrl', u+'matomo.php']);
_paq.push(['setSiteId', '23']);
/* Inject an async <script> for matomo.js before the first script element in the document. */
var d=document, g=d.createElement('script'), s=d.getElementsByTagName('script')[0];
g.async=true; g.src=u+'matomo.js'; s.parentNode.insertBefore(g,s);
})();
</script>
<!-- End Matomo Code -->
<script src="/versions/master/assets/js/jquery-3.3.1.min.js"></script>
<script src="/versions/master/assets/js/docsearch.min.js"></script><script src="/versions/master/assets/js/globalSearch.js" defer></script>
<script src="/versions/master/assets/js/clipboard.js" defer></script>
<script src="/versions/master/assets/js/copycode.js" defer></script></head>
<body><header class="site-header" role="banner">
<script>
// Site header behaviour, run once the DOM is ready:
//  1. fade the banner's background in as the page is scrolled;
//  2. highlight the top-level menu entry whose URL matches the current page.
$(function () {
  // Alpha starts at 0.4 and grows by 1 for every 300px scrolled
  // (values above 1 are clamped by CSS).
  function updateHeaderOpacity() {
    var alpha = $(window).scrollTop() / 300 + 0.4;
    $('.site-header').css("background-color", "rgba(4,140,204," + alpha + ")");
  }

  $(window).on('scroll', function () {
    updateHeaderOpacity();
  });
  updateHeaderOpacity();

  // A link is "current" when its absolute href is a substring of the page URL.
  $('.page-link').each(function (index, link) {
    var pageUrl = window.location.href;
    if (pageUrl.includes(link.href)) {
      $(link).addClass("page-current");
    }
  });
});
</script>
<div class="wrapper">
<a class="site-title" rel="author" href="/versions/master/"><img
src="/versions/master/assets/img/mxnet_logo.png" class="site-header-logo" alt="Apache MXNet home"></a>
<nav class="site-nav">
<input type="checkbox" id="nav-trigger" class="nav-trigger" aria-label="Toggle navigation menu"/>
<label for="nav-trigger">
<span class="menu-icon">
<svg viewBox="0 0 18 15" width="18px" height="15px">
<path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
</svg>
</span>
</label>
<div class="gs-search-border">
<div id="gs-search-icon"></div>
<form id="global-search-form">
<input id="global-search" type="text" title="Search" placeholder="Search" aria-label="Search" />
<div id="global-search-dropdown-container">
<button class="gs-current-version btn" type="button" data-toggle="dropdown">
<span id="gs-current-version-label">master</span>
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown">
<li class="gs-opt gs-versions active">master</li>
<li class="gs-opt gs-versions">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
<span id="global-search-close">x</span>
</form>
</div>
<div class="trigger">
<div id="global-search-mobile-border">
<div id="gs-search-icon-mobile"></div>
<input id="global-search-mobile" placeholder="Search..." type="text" aria-label="Search"/>
<div id="global-search-dropdown-container-mobile">
<button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown" aria-label="Select documentation version">
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown-mobile">
<li class="gs-opt gs-versions active">master</li>
<li class="gs-opt gs-versions">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
</div>
<a class="page-link" href="/versions/master/get_started">Get Started</a>
<a class="page-link" href="/versions/master/features">Features</a>
<a class="page-link" href="/versions/master/ecosystem">Ecosystem</a>
<a class="page-link" href="/versions/master/api">Docs & Tutorials</a>
<a class="page-link" href="/versions/master/trusted_by">Trusted By</a>
<a class="page-link" href="https://github.com/apache/mxnet">GitHub</a>
<div class="dropdown" style="min-width:100px">
<span class="dropdown-header">Apache
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content" style="min-width:250px">
<a href="https://www.apache.org/foundation/">Apache Software Foundation</a>
<a href="https://www.apache.org/licenses/">License</a>
<a href="/versions/master/api/faq/security.html">Security</a>
<a href="https://privacy.apache.org/policies/privacy-policy-public.html">Privacy</a>
<a href="https://www.apache.org/events/current-event">Events</a>
<a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a>
<a href="https://www.apache.org/foundation/thanks.html">Thanks</a>
</div>
</div>
<div class="dropdown">
<span class="dropdown-header">master
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content">
<a class="dropdown-option-active" href="/">master</a>
<a href="/versions/1.9.1/">1.9.1</a>
<a href="/versions/1.8.0/">1.8.0</a>
<a href="/versions/1.7.0/">1.7.0</a>
<a href="/versions/1.6.0/">1.6.0</a>
<a href="/versions/1.5.0/">1.5.0</a>
<a href="/versions/1.4.1/">1.4.1</a>
<a href="/versions/1.3.1/">1.3.1</a>
<a href="/versions/1.2.1/">1.2.1</a>
<a href="/versions/1.1.0/">1.1.0</a>
<a href="/versions/1.0.0/">1.0.0</a>
<a href="/versions/0.12.1/">0.12.1</a>
<a href="/versions/0.11.0/">0.11.0</a>
</div>
</div>
</div>
</nav>
</div>
</header>
<main class="page-content" aria-label="Content">
<script>
</script>
<article class="post">
<header class="post-header wrapper">
<h1 class="post-title">LSTM Time Series</h1>
</header>
<div class="post-content">
<div class="wrapper">
<div class="row">
<div class="col-3 docs-side-bar">
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page -->
</div>
<div class="col-9">
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
<h1 id="lstm-time-series-example">LSTM Time Series Example</h1>
<p>This tutorial shows how to use an LSTM model with multivariate data and generate predictions from it. For demonstration purposes, we use an open-source <a href="https://archive.ics.uci.edu/ml/datasets/Beijing+PM2.5+Data">pollution dataset</a>.
The tutorial illustrates how to use LSTM models with MXNet-R. We forecast air pollution using five years of data recorded at the US embassy in Beijing, China.</p>
<p>Dataset Attribution:
“PM2.5 data of US Embassy in Beijing”
We want to predict pollution levels (PM2.5 concentration) in the city given the above dataset.</p>
<div class="language-r highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">Dataset</span><span class="w"> </span><span class="n">description</span><span class="o">:</span><span class="w">
</span><span class="n">No</span><span class="o">:</span><span class="w"> </span><span class="n">row</span><span class="w"> </span><span class="n">number</span><span class="w">
</span><span class="n">year</span><span class="o">:</span><span class="w"> </span><span class="n">year</span><span class="w"> </span><span class="n">of</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="n">this</span><span class="w"> </span><span class="n">row</span><span class="w">
</span><span class="n">month</span><span class="o">:</span><span class="w"> </span><span class="n">month</span><span class="w"> </span><span class="n">of</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="n">this</span><span class="w"> </span><span class="n">row</span><span class="w">
</span><span class="n">day</span><span class="o">:</span><span class="w"> </span><span class="n">day</span><span class="w"> </span><span class="n">of</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="n">this</span><span class="w"> </span><span class="n">row</span><span class="w">
</span><span class="n">hour</span><span class="o">:</span><span class="w"> </span><span class="n">hour</span><span class="w"> </span><span class="n">of</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="n">this</span><span class="w"> </span><span class="n">row</span><span class="w">
</span><span class="n">pm2.5</span><span class="o">:</span><span class="w"> </span><span class="n">PM2.5</span><span class="w"> </span><span class="n">concentration</span><span class="w">
</span><span class="n">DEWP</span><span class="o">:</span><span class="w"> </span><span class="n">Dew</span><span class="w"> </span><span class="n">Point</span><span class="w">
</span><span class="n">TEMP</span><span class="o">:</span><span class="w"> </span><span class="n">Temperature</span><span class="w">
</span><span class="n">PRES</span><span class="o">:</span><span class="w"> </span><span class="n">Pressure</span><span class="w">
</span><span class="n">cbwd</span><span class="o">:</span><span class="w"> </span><span class="n">Combined</span><span class="w"> </span><span class="n">wind</span><span class="w"> </span><span class="n">direction</span><span class="w">
</span><span class="n">Iws</span><span class="o">:</span><span class="w"> </span><span class="n">Cumulated</span><span class="w"> </span><span class="n">wind</span><span class="w"> </span><span class="n">speed</span><span class="w">
</span><span class="n">Is</span><span class="o">:</span><span class="w"> </span><span class="n">Cumulated</span><span class="w"> </span><span class="n">hours</span><span class="w"> </span><span class="n">of</span><span class="w"> </span><span class="n">snow</span><span class="w">
</span><span class="n">Ir</span><span class="o">:</span><span class="w"> </span><span class="n">Cumulated</span><span class="w"> </span><span class="n">hours</span><span class="w"> </span><span class="n">of</span><span class="w"> </span><span class="n">rain</span><span class="w">
</span></code></pre></div></div>
<p>We use past PM2.5 concentration, dew point, temperature, pressure, wind speed, snow and rain to predict
PM2.5 concentration levels.</p>
<h2 id="load-and-pre-process-the-data">Load and pre-process the data</h2>
<p>The first step is to load the data and preprocess it. We assume the data has been downloaded as a .csv file, data.csv, from the <a href="https://archive.ics.uci.edu/ml/datasets/Beijing+PM2.5+Data">pollution dataset</a>.</p>
<div class="language-r highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">## Loading required packages</span><span class="w">
</span><span class="n">library</span><span class="p">(</span><span class="s2">"readr"</span><span class="p">)</span><span class="w">
</span><span class="n">library</span><span class="p">(</span><span class="s2">"dplyr"</span><span class="p">)</span><span class="w">
</span><span class="n">library</span><span class="p">(</span><span class="s2">"mxnet"</span><span class="p">)</span><span class="w">
</span><span class="n">library</span><span class="p">(</span><span class="s2">"abind"</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<div class="language-r highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">## Preprocessing steps</span><span class="w">
</span><span class="n">Data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">read.csv</span><span class="p">(</span><span class="n">file</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"data.csv"</span><span class="p">,</span><span class="w">
</span><span class="n">header</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">TRUE</span><span class="p">,</span><span class="w">
</span><span class="n">sep</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">","</span><span class="p">)</span><span class="w">
</span><span class="c1">## Extract specific features from the dataset to use as time-series variables. We extract</span><span class="w">
</span><span class="c1">## pollution, dew point, temperature, pressure, wind speed, snowfall and rainfall information from the dataset</span><span class="w">
</span><span class="n">df</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data.frame</span><span class="p">(</span><span class="n">Data</span><span class="o">$</span><span class="n">pm2.5</span><span class="p">,</span><span class="w">
</span><span class="n">Data</span><span class="o">$</span><span class="n">DEWP</span><span class="p">,</span><span class="w">
</span><span class="n">Data</span><span class="o">$</span><span class="n">TEMP</span><span class="p">,</span><span class="w">
</span><span class="n">Data</span><span class="o">$</span><span class="n">PRES</span><span class="p">,</span><span class="w">
</span><span class="n">Data</span><span class="o">$</span><span class="n">Iws</span><span class="p">,</span><span class="w">
</span><span class="n">Data</span><span class="o">$</span><span class="n">Is</span><span class="p">,</span><span class="w">
</span><span class="n">Data</span><span class="o">$</span><span class="n">Ir</span><span class="p">)</span><span class="w">
</span><span class="n">df</span><span class="p">[</span><span class="nf">is.na</span><span class="p">(</span><span class="n">df</span><span class="p">)]</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="m">0</span><span class="w">
</span><span class="c1">## Now we normalise each feature to the range [0, 1]</span><span class="w">
</span><span class="n">df</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">matrix</span><span class="p">(</span><span class="n">as.matrix</span><span class="p">(</span><span class="n">df</span><span class="p">),</span><span class="w">
</span><span class="n">ncol</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">ncol</span><span class="p">(</span><span class="n">df</span><span class="p">),</span><span class="w">
</span><span class="n">dimnames</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">NULL</span><span class="p">)</span><span class="w">
</span><span class="n">rangenorm</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="p">(</span><span class="n">x</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="nf">min</span><span class="p">(</span><span class="n">x</span><span class="p">))</span><span class="o">/</span><span class="p">(</span><span class="nf">max</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="nf">min</span><span class="p">(</span><span class="n">x</span><span class="p">))</span><span class="w">
</span><span class="p">}</span><span class="w">
</span><span class="n">df</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">apply</span><span class="p">(</span><span class="n">df</span><span class="p">,</span><span class="w"> </span><span class="m">2</span><span class="p">,</span><span class="w"> </span><span class="n">rangenorm</span><span class="p">)</span><span class="w">
</span><span class="n">df</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">t</span><span class="p">(</span><span class="n">df</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<p>To use multidimensional data with MXNet-R, we need to convert the training data to the form
(n_dim x seq_len x num_samples). For the one-to-one RNN flavour, labels should be of the form (seq_len x num_samples), while for the many-to-one flavour, labels should be of the form (1 x num_samples). Please note that MXNet-R currently supports only these two flavours of RNN.
We use n_dim = 7, seq_len = 100, and num_samples = 430: the data is split into 430 samples, each 100 time steps long, and because we have seven time series as input features, each input has a dimension of seven at each time step.</p>
<div class="language-r highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n_dim</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="m">7</span><span class="w">
</span><span class="n">seq_len</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="m">100</span><span class="w">
</span><span class="n">num_samples</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="m">430</span><span class="w">
</span><span class="c1">## extract only required data from dataset</span><span class="w">
</span><span class="n">trX</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="p">[</span><span class="m">1</span><span class="o">:</span><span class="n">n_dim</span><span class="p">,</span><span class="w"> </span><span class="m">25</span><span class="o">:</span><span class="p">(</span><span class="m">24</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="p">(</span><span class="n">seq_len</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">num_samples</span><span class="p">))]</span><span class="w">
</span><span class="c1">## the label data(next PM2.5 concentration) should be one time step</span><span class="w">
</span><span class="c1">## ahead of the current PM2.5 concentration</span><span class="w">
</span><span class="n">trY</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">df</span><span class="p">[</span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="m">26</span><span class="o">:</span><span class="p">(</span><span class="m">25</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="p">(</span><span class="n">seq_len</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">num_samples</span><span class="p">))]</span><span class="w">
</span><span class="c1">## reshape the matrices in the format acceptable by MXNetR RNNs</span><span class="w">
</span><span class="n">trainX</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">trX</span><span class="w">
</span><span class="nf">dim</span><span class="p">(</span><span class="n">trainX</span><span class="p">)</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">n_dim</span><span class="p">,</span><span class="w"> </span><span class="n">seq_len</span><span class="p">,</span><span class="w"> </span><span class="n">num_samples</span><span class="p">)</span><span class="w">
</span><span class="n">trainY</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">trY</span><span class="w">
</span><span class="nf">dim</span><span class="p">(</span><span class="n">trainY</span><span class="p">)</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span><span class="w"> </span><span class="n">num_samples</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<h2 id="defining-and-training-the-network">Defining and training the network</h2>
<div class="language-r highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">batch.size</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="m">32</span><span class="w">
</span><span class="c1"># take the first 300 samples for training and the next 100 (301-400) for evaluation</span><span class="w">
</span><span class="n">train_ids</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="m">300</span><span class="w">
</span><span class="n">eval_ids</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="m">301</span><span class="o">:</span><span class="m">400</span><span class="w">
</span><span class="c1">## The number of samples used for training and evaluation is arbitrary. I have kept aside a few</span><span class="w">
</span><span class="c1">## samples for testing purposes. Create the data iterators</span><span class="w">
</span><span class="n">train.data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.io.arrayiter</span><span class="p">(</span><span class="n">data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">trainX</span><span class="p">[,</span><span class="w"> </span><span class="p">,</span><span class="w"> </span><span class="n">train_ids</span><span class="p">,</span><span class="w"> </span><span class="n">drop</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nb">F</span><span class="p">],</span><span class="w">
</span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">trainY</span><span class="p">[,</span><span class="w"> </span><span class="n">train_ids</span><span class="p">],</span><span class="w">
</span><span class="n">batch.size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">batch.size</span><span class="p">,</span><span class="w"> </span><span class="n">shuffle</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">TRUE</span><span class="p">)</span><span class="w">
</span><span class="n">eval.data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.io.arrayiter</span><span class="p">(</span><span class="n">data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">trainX</span><span class="p">[,</span><span class="w"> </span><span class="p">,</span><span class="w"> </span><span class="n">eval_ids</span><span class="p">,</span><span class="w"> </span><span class="n">drop</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nb">F</span><span class="p">],</span><span class="w">
</span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">trainY</span><span class="p">[,</span><span class="w"> </span><span class="n">eval_ids</span><span class="p">],</span><span class="w">
</span><span class="n">batch.size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">batch.size</span><span class="p">,</span><span class="w"> </span><span class="n">shuffle</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">FALSE</span><span class="p">)</span><span class="w">
</span><span class="c1">## Create the symbol for RNN</span><span class="w">
</span><span class="n">symbol</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">rnn.graph</span><span class="p">(</span><span class="n">num_rnn_layer</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w">
</span><span class="n">num_hidden</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">5</span><span class="p">,</span><span class="w">
</span><span class="n">input_size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">NULL</span><span class="p">,</span><span class="w">
</span><span class="n">num_embed</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">NULL</span><span class="p">,</span><span class="w">
</span><span class="n">num_decode</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w">
</span><span class="n">masking</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nb">F</span><span class="p">,</span><span class="w">
</span><span class="n">loss_output</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"linear"</span><span class="p">,</span><span class="w">
</span><span class="n">dropout</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0.2</span><span class="p">,</span><span class="w">
</span><span class="n">ignore_label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">-1</span><span class="p">,</span><span class="w">
</span><span class="n">cell_type</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"lstm"</span><span class="p">,</span><span class="w">
</span><span class="n">output_last_state</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nb">T</span><span class="p">,</span><span class="w">
</span><span class="n">config</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"one-to-one"</span><span class="p">)</span><span class="w">
</span><span class="n">mx.metric.mse.seq</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.metric.custom</span><span class="p">(</span><span class="s2">"MSE"</span><span class="p">,</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">label</span><span class="p">,</span><span class="w"> </span><span class="n">pred</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx.nd.reshape</span><span class="p">(</span><span class="n">label</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">-1</span><span class="p">)</span><span class="w">
</span><span class="n">pred</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx.nd.reshape</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">-1</span><span class="p">)</span><span class="w">
</span><span class="n">res</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.nd.mean</span><span class="p">(</span><span class="n">mx.nd.square</span><span class="p">(</span><span class="n">label</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">pred</span><span class="p">))</span><span class="w">
</span><span class="nf">return</span><span class="p">(</span><span class="n">as.array</span><span class="p">(</span><span class="n">res</span><span class="p">))</span><span class="w">
</span><span class="p">})</span><span class="w">
</span><span class="n">ctx</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.cpu</span><span class="p">()</span><span class="w">
</span><span class="n">initializer</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.init.Xavier</span><span class="p">(</span><span class="n">rnd_type</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"gaussian"</span><span class="p">,</span><span class="w">
</span><span class="n">factor_type</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"avg"</span><span class="p">,</span><span class="w">
</span><span class="n">magnitude</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">3</span><span class="p">)</span><span class="w">
</span><span class="n">optimizer</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.opt.create</span><span class="p">(</span><span class="s2">"adadelta"</span><span class="p">,</span><span class="w">
</span><span class="n">rho</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0.9</span><span class="p">,</span><span class="w">
</span><span class="n">eps</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1e-05</span><span class="p">,</span><span class="w">
</span><span class="n">wd</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1e-06</span><span class="p">,</span><span class="w">
</span><span class="n">clip_gradient</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w">
</span><span class="n">rescale.grad</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="o">/</span><span class="n">batch.size</span><span class="p">)</span><span class="w">
</span><span class="n">logger</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.metric.logger</span><span class="p">()</span><span class="w">
</span><span class="n">epoch.end.callback</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.callback.log.train.metric</span><span class="p">(</span><span class="n">period</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">10</span><span class="p">,</span><span class="w">
</span><span class="n">logger</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">logger</span><span class="p">)</span><span class="w">
</span><span class="c1">## train the network</span><span class="w">
</span><span class="n">system.time</span><span class="p">(</span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.model.buckets</span><span class="p">(</span><span class="n">symbol</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">symbol</span><span class="p">,</span><span class="w">
</span><span class="n">train.data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">train.data</span><span class="p">,</span><span class="w">
</span><span class="n">eval.data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">eval.data</span><span class="p">,</span><span class="w">
</span><span class="n">num.round</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">100</span><span class="p">,</span><span class="w">
</span><span class="n">ctx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">ctx</span><span class="p">,</span><span class="w">
</span><span class="n">verbose</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">TRUE</span><span class="p">,</span><span class="w">
</span><span class="n">metric</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx.metric.mse.seq</span><span class="p">,</span><span class="w">
</span><span class="n">initializer</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">initializer</span><span class="p">,</span><span class="w">
</span><span class="n">optimizer</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">optimizer</span><span class="p">,</span><span class="w">
</span><span class="n">batch.end.callback</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">NULL</span><span class="p">,</span><span class="w">
</span><span class="n">epoch.end.callback</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">epoch.end.callback</span><span class="p">))</span><span class="w">
</span></code></pre></div></div>
<p>Output:</p>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Start training with 1 devices
[1] Train-MSE=0.197570244409144
[1] Validation-MSE=0.0153861071448773
[2] Train-MSE=0.0152517843060195
[2] Validation-MSE=0.0128299412317574
[3] Train-MSE=0.0124418652616441
[3] Validation-MSE=0.010827143676579
[4] Train-MSE=0.0105128229130059
[4] Validation-MSE=0.00940261723008007
[5] Train-MSE=0.00914482437074184
[5] Validation-MSE=0.00830172537826002
[6] Train-MSE=0.00813581114634871
[6] Validation-MSE=0.00747016374953091
[7] Train-MSE=0.00735094994306564
[7] Validation-MSE=0.00679832429159433
[8] Train-MSE=0.00672049634158611
[8] Validation-MSE=0.00623159145470709
[9] Train-MSE=0.00620287149213254
[9] Validation-MSE=0.00577476259786636
[10] Train-MSE=0.00577280316501856
[10] Validation-MSE=0.00539038667920977
..........
..........
[91] Train-MSE=0.00177705133100972
[91] Validation-MSE=0.00154715491225943
[92] Train-MSE=0.00177639147732407
[92] Validation-MSE=0.00154592350008897
[93] Train-MSE=0.00177577760769054
[93] Validation-MSE=0.00154474508599378
[94] Train-MSE=0.0017752077546902
[94] Validation-MSE=0.0015436161775142
[95] Train-MSE=0.00177468206966296
[95] Validation-MSE=0.00154253660002723
[96] Train-MSE=0.00177419915562496
[96] Validation-MSE=0.00154150440357625
[97] Train-MSE=0.0017737578949891
[97] Validation-MSE=0.00154051734716631
[98] Train-MSE=0.00177335749613121
[98] Validation-MSE=0.00153957353904843
[99] Train-MSE=0.00177299699280411
[99] Validation-MSE=0.00153867155313492
[100] Train-MSE=0.00177267640829086
[100] Validation-MSE=0.00153781197150238
user system elapsed
21.937 1.914 13.402
</code></pre></div></div>
<p>The plot below shows how the mean squared error (MSE) changes over the training epochs.</p>
<p><img src="https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/r/images/loss.png?raw=true" alt="Training and validation MSE by epoch" /><!--notebook-skip-line--></p>
<h2 id="inference-on-the-network">Inference on the network</h2>
<p>Now we have trained the network. Let’s use it for inference.</p>
<div class="language-r highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">## We extract the state symbols for RNN</span><span class="w">
</span><span class="n">internals</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">symbol</span><span class="o">$</span><span class="n">get.internals</span><span class="p">()</span><span class="w">
</span><span class="n">sym_state</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">internals</span><span class="o">$</span><span class="n">get.output</span><span class="p">(</span><span class="n">which</span><span class="p">(</span><span class="n">internals</span><span class="o">$</span><span class="n">outputs</span><span class="w"> </span><span class="o">%in%</span><span class="w"> </span><span class="s2">"RNN_state"</span><span class="p">))</span><span class="w">
</span><span class="n">sym_state_cell</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">internals</span><span class="o">$</span><span class="n">get.output</span><span class="p">(</span><span class="n">which</span><span class="p">(</span><span class="n">internals</span><span class="o">$</span><span class="n">outputs</span><span class="w"> </span><span class="o">%in%</span><span class="w"> </span><span class="s2">"RNN_state_cell"</span><span class="p">))</span><span class="w">
</span><span class="n">sym_output</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">internals</span><span class="o">$</span><span class="n">get.output</span><span class="p">(</span><span class="n">which</span><span class="p">(</span><span class="n">internals</span><span class="o">$</span><span class="n">outputs</span><span class="w"> </span><span class="o">%in%</span><span class="w"> </span><span class="s2">"loss_output"</span><span class="p">))</span><span class="w">
</span><span class="n">symbol</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.symbol.Group</span><span class="p">(</span><span class="n">sym_output</span><span class="p">,</span><span class="w"> </span><span class="n">sym_state</span><span class="p">,</span><span class="w"> </span><span class="n">sym_state_cell</span><span class="p">)</span><span class="w">
</span><span class="c1">## We will predict 100 time steps for the 401st sample (the first sample from the test samples)</span><span class="w">
</span><span class="n">pred_length</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="m">100</span><span class="w">
</span><span class="n">predicted</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">numeric</span><span class="p">()</span><span class="w">
</span><span class="c1">## We pass the 400th sample through the network to obtain the RNN hidden and cell states, which we</span><span class="w">
</span><span class="c1">## then use as the starting state for predicting the next 100 time steps.</span><span class="w">
</span><span class="n">data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.nd.array</span><span class="p">(</span><span class="n">trainX</span><span class="p">[,</span><span class="w"> </span><span class="p">,</span><span class="w"> </span><span class="m">400</span><span class="p">,</span><span class="w"> </span><span class="n">drop</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nb">F</span><span class="p">])</span><span class="w">
</span><span class="n">label</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.nd.array</span><span class="p">(</span><span class="n">trainY</span><span class="p">[,</span><span class="w"> </span><span class="m">400</span><span class="p">,</span><span class="w"> </span><span class="n">drop</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nb">F</span><span class="p">])</span><span class="w">
</span><span class="c1">## We create a data iterator for the input. Note that a label is required to create the</span><span class="w">
</span><span class="c1">## iterator but is not used during inference, so dummy label values work too.</span><span class="w">
</span><span class="n">infer.data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.io.arrayiter</span><span class="p">(</span><span class="n">data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">data</span><span class="p">,</span><span class="w">
</span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">label</span><span class="p">,</span><span class="w">
</span><span class="n">batch.size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w">
</span><span class="n">shuffle</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">FALSE</span><span class="p">)</span><span class="w">
</span><span class="n">infer</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.infer.rnn.one</span><span class="p">(</span><span class="n">infer.data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer.data</span><span class="p">,</span><span class="w">
</span><span class="n">symbol</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">symbol</span><span class="p">,</span><span class="w">
</span><span class="n">arg.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">arg.params</span><span class="p">,</span><span class="w">
</span><span class="n">aux.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">aux.params</span><span class="p">,</span><span class="w">
</span><span class="n">input.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">NULL</span><span class="p">,</span><span class="w">
</span><span class="n">ctx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">ctx</span><span class="p">)</span><span class="w">
</span><span class="c1">## Starting from the RNN states obtained for the above time series, we predict the next 100 steps,</span><span class="w">
</span><span class="c1">## which correspond to our 401st time series.</span><span class="w">
</span><span class="n">actual</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">trainY</span><span class="p">[,</span><span class="w"> </span><span class="m">401</span><span class="p">]</span><span class="w">
</span><span class="c1">## Now we iterate step by step to predict the pollution value at each of the next time steps</span><span class="w">
</span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="n">i</span><span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="n">pred_length</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="n">data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.nd.array</span><span class="p">(</span><span class="n">trainX</span><span class="p">[,</span><span class="w"> </span><span class="n">i</span><span class="p">,</span><span class="w"> </span><span class="m">401</span><span class="p">,</span><span class="w"> </span><span class="n">drop</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nb">F</span><span class="p">])</span><span class="w">
</span><span class="n">label</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.nd.array</span><span class="p">(</span><span class="n">trainY</span><span class="p">[</span><span class="n">i</span><span class="p">,</span><span class="w"> </span><span class="m">401</span><span class="p">,</span><span class="w"> </span><span class="n">drop</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nb">F</span><span class="p">])</span><span class="w">
</span><span class="n">infer.data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.io.arrayiter</span><span class="p">(</span><span class="n">data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">data</span><span class="p">,</span><span class="w">
</span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">label</span><span class="p">,</span><span class="w">
</span><span class="n">batch.size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w">
</span><span class="n">shuffle</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">FALSE</span><span class="p">)</span><span class="w">
</span><span class="c1">## Note that we feed in the RNN state values from the previous iteration here</span><span class="w">
</span><span class="n">infer</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.infer.rnn.one</span><span class="p">(</span><span class="n">infer.data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer.data</span><span class="p">,</span><span class="w">
</span><span class="n">symbol</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">symbol</span><span class="p">,</span><span class="w">
</span><span class="n">ctx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">ctx</span><span class="p">,</span><span class="w">
</span><span class="n">arg.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">arg.params</span><span class="p">,</span><span class="w">
</span><span class="n">aux.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">aux.params</span><span class="p">,</span><span class="w">
</span><span class="n">input.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">rnn.state</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer</span><span class="p">[[</span><span class="m">2</span><span class="p">]],</span><span class="w">
</span><span class="n">rnn.state.cell</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer</span><span class="p">[[</span><span class="m">3</span><span class="p">]]))</span><span class="w">
</span><span class="n">pred</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">infer</span><span class="p">[[</span><span class="m">1</span><span class="p">]]</span><span class="w">
</span><span class="n">predicted</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">predicted</span><span class="p">,</span><span class="w"> </span><span class="nf">as.numeric</span><span class="p">(</span><span class="n">as.array</span><span class="p">(</span><span class="n">pred</span><span class="p">)))</span><span class="w">
</span><span class="p">}</span><span class="w">
</span></code></pre></div></div>
<p>Now <code>predicted</code> contains the 100 predicted values. We use ggplot to plot the actual and predicted values, as shown below.</p>
<p><img src="https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/r/images/sample_401.png?raw=true" alt="Actual vs. predicted values for the 401st time series" /><!--notebook-skip-line--></p>
<p>We also repeated the above experiment to predict the next 100 values of the 301st time series and got the following results.</p>
<p><img src="https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/r/images/sample_301.png?raw=true" alt="Actual vs. predicted values for the 301st time series" /><!--notebook-skip-line--></p>
<p>The above tutorial is for demonstration purposes only and has not been extensively tuned for accuracy.</p>
<p>For more tutorials on MXNet-R, head over to the <a href="/versions/master/api/r/docs/tutorials">MXNet-R tutorials</a>.</p>
</div>
</div>
</div>
</div>
</article>
</main><footer class="site-footer h-card">
<div class="wrapper">
<div class="row">
<div class="col-4">
<h4 class="footer-category-title">Resources</h4>
<ul class="contact-list">
<li><a href="/versions/master/community#stay-connected">Mailing lists</a></li>
<li><a href="/versions/master/community#github-issues">Github Issues</a></li>
<li><a href="https://github.com/apache/mxnet/projects">Projects</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a href="https://discuss.mxnet.io">Forum</a></li>
<li><a href="/versions/master/community">Contribute To MXNet</a></li>
</ul>
</div>
<div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/mxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul>
</div>
<div class="col-4 footer-text">
<p>A flexible and efficient library for deep learning.</p>
</div>
</div>
</div>
</footer>
<footer class="site-footer2">
<div class="wrapper">
<div class="row">
<div class="col-3">
<img src="/versions/master/assets/img/asf_logo.svg" class="footer-logo col-2" alt="The Apache Software Foundation logo">
</div>
<div class="footer-bottom-warning col-9">
<p>Copyright © 2017-2022, The Apache Software Foundation. Licensed under the Apache License, Version 2.0. Apache MXNet, MXNet, Apache, the Apache
feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the
Apache Software Foundation.</p>
</div>
</div>
</div>
</footer>
</body>
</html>