<!DOCTYPE html>
<html lang="en"><head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 -->
<title>Char RNN Model | Apache MXNet</title>
<meta name="generator" content="Jekyll v4.0.0" />
<meta property="og:title" content="Char RNN Model" />
<meta property="og:locale" content="en_US" />
<meta name="description" content="A flexible and efficient library for deep learning." />
<meta property="og:description" content="A flexible and efficient library for deep learning." />
<link rel="canonical" href="https://mxnet.apache.org/versions/master/api/r/docs/tutorials/char_rnn_model" />
<meta property="og:url" content="https://mxnet.apache.org/versions/master/api/r/docs/tutorials/char_rnn_model" />
<meta property="og:site_name" content="Apache MXNet" />
<script type="application/ld+json">
{"url":"https://mxnet.apache.org/versions/master/api/r/docs/tutorials/char_rnn_model","headline":"Char RNN Model","description":"A flexible and efficient library for deep learning.","@type":"WebPage","@context":"https://schema.org"}</script>
<!-- End Jekyll SEO tag -->
<script src="https://medium-widget.pixelpoint.io/widget.js"></script>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.css" />
<link rel="stylesheet" href="/versions/master/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/master/feed.xml" title="Apache MXNet" /><script>
// Run Google Analytics only when no Do Not Track signal is set
// (checks window.doNotTrack, navigator.doNotTrack and navigator.msDoNotTrack).
if(!(window.doNotTrack === "1" || navigator.doNotTrack === "1" || navigator.doNotTrack === "yes" || navigator.msDoNotTrack === "1")) {
// Loader: records the name 'ga' in window.GoogleAnalyticsObject, defines a global
// `ga` function (unless one already exists) that pushes its arguments onto `ga.q`,
// stores the current time in `ga.l`, then inserts an async <script> pointing at
// analytics.js before the first <script> element in the document.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
// Queue the 'create' command for property UA-96378503-1, then a 'pageview' send.
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
}
</script>
<script src="/versions/master/assets/js/jquery-3.3.1.min.js"></script><script src="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js" defer></script>
<script src="/versions/master/assets/js/globalSearch.js" defer></script>
<script src="/versions/master/assets/js/clipboard.js" defer></script>
<script src="/versions/master/assets/js/copycode.js" defer></script></head>
<body><header class="site-header" role="banner">
<script>
$(document).ready(function () {
  // Header background: the alpha channel starts at 0.4 and grows by 1/300
  // for every pixel the window has been scrolled.
  var refreshHeaderBackground = function () {
    var alpha = $(window).scrollTop() / 300 + 0.4;
    var color = "rgba(4,140,204," + alpha + ")";
    $('.site-header').css("background-color", color);
  };

  // Recompute on every scroll event, and once up front for the initial position.
  $(window).on("scroll", function () {
    refreshHeaderBackground();
  });
  refreshHeaderBackground();

  // Menu highlighting: mark every nav link whose resolved href is contained
  // in the current page URL.
  var currentUrl = window.location.href;
  $('.page-link')
    .filter(function () {
      return currentUrl.includes(this.href);
    })
    .addClass("page-current");
});
</script>
<div class="wrapper">
<a class="site-title" rel="author" href="/versions/master/"><img
src="/versions/master/assets/img/mxnet_logo.png" class="site-header-logo" alt="Apache MXNet"></a>
<nav class="site-nav">
<input type="checkbox" id="nav-trigger" class="nav-trigger"/>
<label for="nav-trigger">
<span class="menu-icon">
<svg viewBox="0 0 18 15" width="18px" height="15px">
<path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
</svg>
</span>
</label>
<div class="gs-search-border">
<div id="gs-search-icon"></div>
<form id="global-search-form">
<input id="global-search" type="text" title="Search" placeholder="Search" />
<div id="global-search-dropdown-container">
<button class="gs-current-version btn" type="button" data-toggle="dropdown">
<span id="gs-current-version-label">master</span>
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown">
<li class="gs-opt gs-versions active">master</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
<span id="global-search-close">x</span>
</form>
</div>
<div class="trigger">
<div id="global-search-mobile-border">
<div id="gs-search-icon-mobile"></div>
<input id="global-search-mobile" placeholder="Search..." type="text"/>
<div id="global-search-dropdown-container-mobile">
<button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown">
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown-mobile">
<li class="gs-opt gs-versions active">master</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
</div>
<a class="page-link" href="/versions/master/get_started">Get Started</a>
<a class="page-link" href="/versions/master/blog">Blog</a>
<a class="page-link" href="/versions/master/features">Features</a>
<a class="page-link" href="/versions/master/ecosystem">Ecosystem</a>
<a class="page-link" href="/versions/master/api">Docs &amp; Tutorials</a>
<a class="page-link" href="https://github.com/apache/incubator-mxnet">GitHub</a>
<div class="dropdown">
<span class="dropdown-header">master
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content">
<a class="dropdown-option-active" href="/">master</a>
<a href="/versions/1.7.0/">1.7.0</a>
<a href="/versions/1.6.0/">1.6.0</a>
<a href="/versions/1.5.0/">1.5.0</a>
<a href="/versions/1.4.1/">1.4.1</a>
<a href="/versions/1.3.1/">1.3.1</a>
<a href="/versions/1.2.1/">1.2.1</a>
<a href="/versions/1.1.0/">1.1.0</a>
<a href="/versions/1.0.0/">1.0.0</a>
<a href="/versions/0.12.1/">0.12.1</a>
<a href="/versions/0.11.0/">0.11.0</a>
</div>
</div>
</div>
</nav>
</div>
</header>
<main class="page-content" aria-label="Content">
<script>
</script>
<article class="post">
<header class="post-header wrapper">
<h1 class="post-title">Char RNN Model</h1>
<h3></h3></header>
<div class="post-content">
<div class="wrapper">
<div class="row">
<div class="col-3 docs-side-bar">
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page -->
</div>
<div class="col-9">
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
<h1 id="character-level-language-model-using-rnn">Character-level Language Model using RNN</h1>
<p>This tutorial demonstrates how to build a character-level language model with an RNN using the MXNet R package. You will need the following R packages to run this tutorial:</p>
<ul>
<li>readr</li>
<li>stringr</li>
<li>stringi</li>
<li>mxnet</li>
</ul>
<p>We will use the <a href="https://github.com/dmlc/web-data/tree/master/mxnet/tinyshakespeare">tinyshakespeare</a> dataset to build this model.</p>
<div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">library</span><span class="p">(</span><span class="s2">"readr"</span><span class="p">)</span><span class="w">
</span><span class="n">library</span><span class="p">(</span><span class="s2">"stringr"</span><span class="p">)</span><span class="w">
</span><span class="n">library</span><span class="p">(</span><span class="s2">"stringi"</span><span class="p">)</span><span class="w">
</span><span class="n">library</span><span class="p">(</span><span class="s2">"mxnet"</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<h2 id="preprocess-and-prepare-the-data">Preprocess and prepare the data</h2>
<p>Download the data:</p>
<div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">download.data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="n">dir.create</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span><span class="w"> </span><span class="n">showWarnings</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">FALSE</span><span class="p">)</span><span class="w">
</span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="o">!</span><span class="n">file.exists</span><span class="p">(</span><span class="n">paste0</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span><span class="s1">'input.txt'</span><span class="p">)))</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="n">download.file</span><span class="p">(</span><span class="n">url</span><span class="o">=</span><span class="s1">'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt'</span><span class="p">,</span><span class="w">
</span><span class="n">destfile</span><span class="o">=</span><span class="n">paste0</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span><span class="s1">'input.txt'</span><span class="p">),</span><span class="w"> </span><span class="n">method</span><span class="o">=</span><span class="s1">'wget'</span><span class="p">)</span><span class="w">
</span><span class="p">}</span><span class="w">
</span><span class="p">}</span><span class="w">
</span></code></pre></div></div>
<p>Next, we transform the text into feature vectors that are fed into the RNN model. The <code class="highlighter-rouge">make_data</code> function reads the dataset, converts it to ASCII and strips any non-printable characters, splits it into individual characters, and groups them into sequences of length <code class="highlighter-rouge">seq.len</code>.</p>
<div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">make_data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">path</span><span class="p">,</span><span class="w"> </span><span class="n">seq.len</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">32</span><span class="p">,</span><span class="w"> </span><span class="n">dic</span><span class="o">=</span><span class="kc">NULL</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="n">text_vec</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">read_file</span><span class="p">(</span><span class="n">file</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">path</span><span class="p">)</span><span class="w">
</span><span class="n">text_vec</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">stri_enc_toascii</span><span class="p">(</span><span class="n">str</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">text_vec</span><span class="p">)</span><span class="w">
</span><span class="n">text_vec</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">str_replace_all</span><span class="p">(</span><span class="n">string</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">text_vec</span><span class="p">,</span><span class="w"> </span><span class="n">pattern</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"[^[:print:]]"</span><span class="p">,</span><span class="w"> </span><span class="n">replacement</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">""</span><span class="p">)</span><span class="w">
</span><span class="n">text_vec</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">strsplit</span><span class="p">(</span><span class="n">text_vec</span><span class="p">,</span><span class="w"> </span><span class="s1">''</span><span class="p">)</span><span class="w"> </span><span class="o">%&gt;%</span><span class="w"> </span><span class="n">unlist</span><span class="w">
</span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="nf">is.null</span><span class="p">(</span><span class="n">dic</span><span class="p">))</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="n">char_keep</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">sort</span><span class="p">(</span><span class="n">unique</span><span class="p">(</span><span class="n">text_vec</span><span class="p">))</span><span class="w">
</span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="n">char_keep</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">names</span><span class="p">(</span><span class="n">dic</span><span class="p">)[</span><span class="o">!</span><span class="n">dic</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="m">0</span><span class="p">]</span><span class="w">
</span><span class="c1"># Remove terms not part of dictionary</span><span class="w">
</span><span class="n">text_vec</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">text_vec</span><span class="p">[</span><span class="n">text_vec</span><span class="w"> </span><span class="o">%in%</span><span class="w"> </span><span class="n">char_keep</span><span class="p">]</span><span class="w">
</span><span class="c1"># Build dictionary</span><span class="w">
</span><span class="n">dic</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="nf">length</span><span class="p">(</span><span class="n">char_keep</span><span class="p">)</span><span class="w">
</span><span class="nf">names</span><span class="p">(</span><span class="n">dic</span><span class="p">)</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">char_keep</span><span class="w">
</span><span class="c1"># reverse dictionary</span><span class="w">
</span><span class="n">rev_dic</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">names</span><span class="p">(</span><span class="n">dic</span><span class="p">)</span><span class="w">
</span><span class="nf">names</span><span class="p">(</span><span class="n">rev_dic</span><span class="p">)</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">dic</span><span class="w">
</span><span class="c1"># Adjust by -1 to have a 1-lag for labels</span><span class="w">
</span><span class="n">num.seq</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="p">(</span><span class="nf">length</span><span class="p">(</span><span class="n">text_vec</span><span class="p">)</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="m">1</span><span class="p">)</span><span class="w"> </span><span class="o">%/%</span><span class="w"> </span><span class="n">seq.len</span><span class="w">
</span><span class="n">features</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">dic</span><span class="p">[</span><span class="n">text_vec</span><span class="p">[</span><span class="m">1</span><span class="o">:</span><span class="p">(</span><span class="n">seq.len</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">num.seq</span><span class="p">)]]</span><span class="w">
</span><span class="n">labels</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">dic</span><span class="p">[</span><span class="n">text_vec</span><span class="p">[</span><span class="m">1</span><span class="o">:</span><span class="p">(</span><span class="n">seq.len</span><span class="o">*</span><span class="n">num.seq</span><span class="p">)</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="m">1</span><span class="p">]]</span><span class="w">
</span><span class="n">features_array</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="n">dim</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">seq.len</span><span class="p">,</span><span class="w"> </span><span class="n">num.seq</span><span class="p">))</span><span class="w">
</span><span class="n">labels_array</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span><span class="w"> </span><span class="n">dim</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">seq.len</span><span class="p">,</span><span class="w"> </span><span class="n">num.seq</span><span class="p">))</span><span class="w">
</span><span class="n">return</span><span class="w"> </span><span class="p">(</span><span class="nf">list</span><span class="p">(</span><span class="n">features_array</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">features_array</span><span class="p">,</span><span class="w"> </span><span class="n">labels_array</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">labels_array</span><span class="p">,</span><span class="w"> </span><span class="n">dic</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">dic</span><span class="p">,</span><span class="w"> </span><span class="n">rev_dic</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">rev_dic</span><span class="p">))</span><span class="w">
</span><span class="p">}</span><span class="w">
</span><span class="n">seq.len</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="m">100</span><span class="w">
</span><span class="n">data_prep</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">make_data</span><span class="p">(</span><span class="n">path</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"input.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">seq.len</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">seq.len</span><span class="p">,</span><span class="w"> </span><span class="n">dic</span><span class="o">=</span><span class="kc">NULL</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<p>Fetch the features and labels for training the model, and split the data into training and evaluation sets in a 9:1 ratio.</p>
<div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">X</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data_prep</span><span class="o">$</span><span class="n">features_array</span><span class="w">
</span><span class="n">Y</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data_prep</span><span class="o">$</span><span class="n">labels_array</span><span class="w">
</span><span class="n">dic</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data_prep</span><span class="o">$</span><span class="n">dic</span><span class="w">
</span><span class="n">rev_dic</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">data_prep</span><span class="o">$</span><span class="n">rev_dic</span><span class="w">
</span><span class="n">vocab</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">length</span><span class="p">(</span><span class="n">dic</span><span class="p">)</span><span class="w">
</span><span class="n">samples</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">tail</span><span class="p">(</span><span class="nf">dim</span><span class="p">(</span><span class="n">X</span><span class="p">),</span><span class="w"> </span><span class="m">1</span><span class="p">)</span><span class="w">
</span><span class="n">train.val.fraction</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="m">0.9</span><span class="w">
</span><span class="n">X.train.data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">X</span><span class="p">[,</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="nf">as.integer</span><span class="p">(</span><span class="n">samples</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">train.val.fraction</span><span class="p">)]</span><span class="w">
</span><span class="n">X.val.data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">X</span><span class="p">[,</span><span class="w"> </span><span class="o">-</span><span class="p">(</span><span class="m">1</span><span class="o">:</span><span class="nf">as.integer</span><span class="p">(</span><span class="n">samples</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">train.val.fraction</span><span class="p">))]</span><span class="w">
</span><span class="n">X.train.label</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">Y</span><span class="p">[,</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="nf">as.integer</span><span class="p">(</span><span class="n">samples</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">train.val.fraction</span><span class="p">)]</span><span class="w">
</span><span class="n">X.val.label</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">Y</span><span class="p">[,</span><span class="w"> </span><span class="o">-</span><span class="p">(</span><span class="m">1</span><span class="o">:</span><span class="nf">as.integer</span><span class="p">(</span><span class="n">samples</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">train.val.fraction</span><span class="p">))]</span><span class="w">
</span><span class="n">train_buckets</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="s2">"100"</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">X.train.data</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">X.train.label</span><span class="p">))</span><span class="w">
</span><span class="n">eval_buckets</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="s2">"100"</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">X.val.data</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">X.val.label</span><span class="p">))</span><span class="w">
</span><span class="n">train_buckets</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">buckets</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">train_buckets</span><span class="p">,</span><span class="w"> </span><span class="n">dic</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">dic</span><span class="p">,</span><span class="w"> </span><span class="n">rev_dic</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">rev_dic</span><span class="p">)</span><span class="w">
</span><span class="n">eval_buckets</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">buckets</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">eval_buckets</span><span class="p">,</span><span class="w"> </span><span class="n">dic</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">dic</span><span class="p">,</span><span class="w"> </span><span class="n">rev_dic</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">rev_dic</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<p>Create iterators for training and evaluation datasets.</p>
<div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">vocab</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">length</span><span class="p">(</span><span class="n">eval_buckets</span><span class="o">$</span><span class="n">dic</span><span class="p">)</span><span class="w">
</span><span class="n">batch.size</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="m">32</span><span class="w">
</span><span class="n">train.data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.io.bucket.iter</span><span class="p">(</span><span class="n">buckets</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">train_buckets</span><span class="o">$</span><span class="n">buckets</span><span class="p">,</span><span class="w"> </span><span class="n">batch.size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">batch.size</span><span class="p">,</span><span class="w">
</span><span class="n">data.mask.element</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0</span><span class="p">,</span><span class="w"> </span><span class="n">shuffle</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">TRUE</span><span class="p">)</span><span class="w">
</span><span class="n">eval.data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.io.bucket.iter</span><span class="p">(</span><span class="n">buckets</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">eval_buckets</span><span class="o">$</span><span class="n">buckets</span><span class="p">,</span><span class="w"> </span><span class="n">batch.size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">batch.size</span><span class="p">,</span><span class="w">
</span><span class="n">data.mask.element</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0</span><span class="p">,</span><span class="w"> </span><span class="n">shuffle</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">FALSE</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<h2 id="train-the-model">Train the Model</h2>
<p>This model is a multi-layer RNN for sampling from character-level language models. It has a one-to-one model configuration since, for each character, we want to predict the next one. For a sequence of length 100, there are also 100 labels, corresponding to the same sequence of characters but offset by a position of +1. The parameter output_last_state is set to TRUE in order to access the state of the RNN cells when performing inference.</p>
<div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">rnn_graph_one_one</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">rnn.graph</span><span class="p">(</span><span class="n">num_rnn_layer</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">3</span><span class="p">,</span><span class="w">
</span><span class="n">num_hidden</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">96</span><span class="p">,</span><span class="w">
</span><span class="n">input_size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">vocab</span><span class="p">,</span><span class="w">
</span><span class="n">num_embed</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">64</span><span class="p">,</span><span class="w">
</span><span class="n">num_decode</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">vocab</span><span class="p">,</span><span class="w">
</span><span class="n">dropout</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0.2</span><span class="p">,</span><span class="w">
</span><span class="n">ignore_label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0</span><span class="p">,</span><span class="w">
</span><span class="n">cell_type</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"lstm"</span><span class="p">,</span><span class="w">
</span><span class="n">masking</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nb">F</span><span class="p">,</span><span class="w">
</span><span class="n">output_last_state</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nb">T</span><span class="p">,</span><span class="w">
</span><span class="n">loss_output</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"softmax"</span><span class="p">,</span><span class="w">
</span><span class="n">config</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"one-to-one"</span><span class="p">)</span><span class="w">
</span><span class="n">graph.viz</span><span class="p">(</span><span class="n">rnn_graph_one_one</span><span class="p">,</span><span class="w"> </span><span class="n">type</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"graph"</span><span class="p">,</span><span class="w"> </span><span class="n">direction</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"LR"</span><span class="p">,</span><span class="w">
</span><span class="n">graph.height.px</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">180</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="m">100</span><span class="p">,</span><span class="w"> </span><span class="m">64</span><span class="p">))</span><span class="w">
</span><span class="n">devices</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.cpu</span><span class="p">()</span><span class="w">
</span><span class="n">initializer</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.init.Xavier</span><span class="p">(</span><span class="n">rnd_type</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"gaussian"</span><span class="p">,</span><span class="w"> </span><span class="n">factor_type</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"avg"</span><span class="p">,</span><span class="w"> </span><span class="n">magnitude</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">3</span><span class="p">)</span><span class="w">
</span><span class="n">optimizer</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.opt.create</span><span class="p">(</span><span class="s2">"adadelta"</span><span class="p">,</span><span class="w"> </span><span class="n">rho</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0.9</span><span class="p">,</span><span class="w"> </span><span class="n">eps</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1e-5</span><span class="p">,</span><span class="w"> </span><span class="n">wd</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1e-8</span><span class="p">,</span><span class="w">
</span><span class="n">clip_gradient</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">5</span><span class="p">,</span><span class="w"> </span><span class="n">rescale.grad</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="o">/</span><span class="n">batch.size</span><span class="p">)</span><span class="w">
</span><span class="n">logger</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.metric.logger</span><span class="p">()</span><span class="w">
</span><span class="n">epoch.end.callback</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.callback.log.train.metric</span><span class="p">(</span><span class="n">period</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">logger</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">logger</span><span class="p">)</span><span class="w">
</span><span class="n">batch.end.callback</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.callback.log.train.metric</span><span class="p">(</span><span class="n">period</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">50</span><span class="p">)</span><span class="w">
</span><span class="n">mx.metric.custom_nd</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">name</span><span class="p">,</span><span class="w"> </span><span class="n">feval</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="n">init</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="k">function</span><span class="p">()</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="nf">c</span><span class="p">(</span><span class="m">0</span><span class="p">,</span><span class="w"> </span><span class="m">0</span><span class="p">)</span><span class="w">
</span><span class="p">}</span><span class="w">
</span><span class="n">update</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">label</span><span class="p">,</span><span class="w"> </span><span class="n">pred</span><span class="p">,</span><span class="w"> </span><span class="n">state</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="n">m</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">feval</span><span class="p">(</span><span class="n">label</span><span class="p">,</span><span class="w"> </span><span class="n">pred</span><span class="p">)</span><span class="w">
</span><span class="n">state</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">state</span><span class="p">[[</span><span class="m">1</span><span class="p">]]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">state</span><span class="p">[[</span><span class="m">2</span><span class="p">]]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">m</span><span class="p">)</span><span class="w">
</span><span class="nf">return</span><span class="p">(</span><span class="n">state</span><span class="p">)</span><span class="w">
</span><span class="p">}</span><span class="w">
</span><span class="n">get</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">state</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="nf">list</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span><span class="w"> </span><span class="n">value</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="p">(</span><span class="n">state</span><span class="p">[[</span><span class="m">2</span><span class="p">]]</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="n">state</span><span class="p">[[</span><span class="m">1</span><span class="p">]]))</span><span class="w">
</span><span class="p">}</span><span class="w">
</span><span class="n">ret</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="p">(</span><span class="nf">list</span><span class="p">(</span><span class="n">init</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">init</span><span class="p">,</span><span class="w"> </span><span class="n">update</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">update</span><span class="p">,</span><span class="w"> </span><span class="n">get</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">get</span><span class="p">))</span><span class="w">
</span><span class="nf">class</span><span class="p">(</span><span class="n">ret</span><span class="p">)</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="s2">"mx.metric"</span><span class="w">
</span><span class="nf">return</span><span class="p">(</span><span class="n">ret</span><span class="p">)</span><span class="w">
</span><span class="p">}</span><span class="w">
</span><span class="n">mx.metric.Perplexity</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.metric.custom_nd</span><span class="p">(</span><span class="s2">"Perplexity"</span><span class="p">,</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">label</span><span class="p">,</span><span class="w"> </span><span class="n">pred</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="n">label</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.nd.reshape</span><span class="p">(</span><span class="n">label</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">-1</span><span class="p">)</span><span class="w">
</span><span class="n">label_probs</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">as.array</span><span class="p">(</span><span class="n">mx.nd.choose.element.0index</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="p">))</span><span class="w">
</span><span class="n">batch</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">length</span><span class="p">(</span><span class="n">label_probs</span><span class="p">)</span><span class="w">
</span><span class="n">NLL</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="o">-</span><span class="nf">sum</span><span class="p">(</span><span class="nf">log</span><span class="p">(</span><span class="n">pmax</span><span class="p">(</span><span class="m">1e-15</span><span class="p">,</span><span class="w"> </span><span class="n">as.array</span><span class="p">(</span><span class="n">label_probs</span><span class="p">))))</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="n">batch</span><span class="w">
</span><span class="n">Perplexity</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">exp</span><span class="p">(</span><span class="n">NLL</span><span class="p">)</span><span class="w">
</span><span class="nf">return</span><span class="p">(</span><span class="n">Perplexity</span><span class="p">)</span><span class="w">
</span><span class="p">})</span><span class="w">
</span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.model.buckets</span><span class="p">(</span><span class="n">symbol</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">rnn_graph_one_one</span><span class="p">,</span><span class="w">
</span><span class="n">train.data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">train.data</span><span class="p">,</span><span class="w"> </span><span class="n">eval.data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">eval.data</span><span class="p">,</span><span class="w">
</span><span class="n">num.round</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">20</span><span class="p">,</span><span class="w"> </span><span class="n">ctx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">devices</span><span class="p">,</span><span class="w"> </span><span class="n">verbose</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">TRUE</span><span class="p">,</span><span class="w">
</span><span class="n">metric</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx.metric.Perplexity</span><span class="p">,</span><span class="w">
</span><span class="n">initializer</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">initializer</span><span class="p">,</span><span class="w"> </span><span class="n">optimizer</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">optimizer</span><span class="p">,</span><span class="w">
</span><span class="n">batch.end.callback</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">NULL</span><span class="p">,</span><span class="w">
</span><span class="n">epoch.end.callback</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">epoch.end.callback</span><span class="p">)</span><span class="w">
</span><span class="n">mx.model.save</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">prefix</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"one_to_one_seq_model"</span><span class="p">,</span><span class="w"> </span><span class="n">iteration</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">20</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Start training with 1 devices
[1] Train-Perplexity=13.7040474322178
[1] Validation-Perplexity=7.94617194460922
[2] Train-Perplexity=6.57039815554525
[2] Validation-Perplexity=6.60806110658011
[3] Train-Perplexity=5.65360504501481
[3] Validation-Perplexity=6.18932770630876
[4] Train-Perplexity=5.32547285727298
[4] Validation-Perplexity=6.02198756798859
[5] Train-Perplexity=5.14373631472579
[5] Validation-Perplexity=5.8095658243407
[6] Train-Perplexity=5.03077673487379
[6] Validation-Perplexity=5.72582993567431
[7] Train-Perplexity=4.94453383291536
[7] Validation-Perplexity=5.6445258528126
[8] Train-Perplexity=4.88635290100261
[8] Validation-Perplexity=5.6730024536433
[9] Train-Perplexity=4.84205646230548
[9] Validation-Perplexity=5.50960780230982
[10] Train-Perplexity=4.80441673535513
[10] Validation-Perplexity=5.57002263750006
[11] Train-Perplexity=4.77763413242626
[11] Validation-Perplexity=5.55152143269169
[12] Train-Perplexity=4.74937775290777
[12] Validation-Perplexity=5.44968305351486
[13] Train-Perplexity=4.72824849541467
[13] Validation-Perplexity=5.50889348298234
[14] Train-Perplexity=4.70980846981694
[14] Validation-Perplexity=5.51473225859859
[15] Train-Perplexity=4.69685776886122
[15] Validation-Perplexity=5.45391985233811
[16] Train-Perplexity=4.67837107034824
[16] Validation-Perplexity=5.46636764997829
[17] Train-Perplexity=4.66866961934873
[17] Validation-Perplexity=5.44267086113492
[18] Train-Perplexity=4.65611469144194
[18] Validation-Perplexity=5.4290169469462
[19] Train-Perplexity=4.64614689879405
[19] Validation-Perplexity=5.44221549833917
[20] Train-Perplexity=4.63764001963654
[20] Validation-Perplexity=5.42114250842862
</code></pre></div></div>
<h2 id="inference-on-the-model">Inference on the Model</h2>
<p>We now use the saved model to run inference, sampling text character by character so that the result resembles the original training data.</p>
<div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">set.seed</span><span class="p">(</span><span class="m">0</span><span class="p">)</span><span class="w">
</span><span class="n">model</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.model.load</span><span class="p">(</span><span class="n">prefix</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"one_to_one_seq_model"</span><span class="p">,</span><span class="w"> </span><span class="n">iteration</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">20</span><span class="p">)</span><span class="w">
</span><span class="n">internals</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">symbol</span><span class="o">$</span><span class="n">get.internals</span><span class="p">()</span><span class="w">
</span><span class="n">sym_state</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">internals</span><span class="o">$</span><span class="n">get.output</span><span class="p">(</span><span class="n">which</span><span class="p">(</span><span class="n">internals</span><span class="o">$</span><span class="n">outputs</span><span class="w"> </span><span class="o">%in%</span><span class="w"> </span><span class="s2">"RNN_state"</span><span class="p">))</span><span class="w">
</span><span class="n">sym_state_cell</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">internals</span><span class="o">$</span><span class="n">get.output</span><span class="p">(</span><span class="n">which</span><span class="p">(</span><span class="n">internals</span><span class="o">$</span><span class="n">outputs</span><span class="w"> </span><span class="o">%in%</span><span class="w"> </span><span class="s2">"RNN_state_cell"</span><span class="p">))</span><span class="w">
</span><span class="n">sym_output</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">internals</span><span class="o">$</span><span class="n">get.output</span><span class="p">(</span><span class="n">which</span><span class="p">(</span><span class="n">internals</span><span class="o">$</span><span class="n">outputs</span><span class="w"> </span><span class="o">%in%</span><span class="w"> </span><span class="s2">"loss_output"</span><span class="p">))</span><span class="w">
</span><span class="n">symbol</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.symbol.Group</span><span class="p">(</span><span class="n">sym_output</span><span class="p">,</span><span class="w"> </span><span class="n">sym_state</span><span class="p">,</span><span class="w"> </span><span class="n">sym_state_cell</span><span class="p">)</span><span class="w">
</span><span class="n">infer_raw</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="s2">"Thou "</span><span class="p">)</span><span class="w">
</span><span class="n">infer_split</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">dic</span><span class="p">[</span><span class="n">strsplit</span><span class="p">(</span><span class="n">infer_raw</span><span class="p">,</span><span class="w"> </span><span class="s1">''</span><span class="p">)</span><span class="w"> </span><span class="o">%&gt;%</span><span class="w"> </span><span class="n">unlist</span><span class="p">]</span><span class="w">
</span><span class="n">infer_length</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">length</span><span class="p">(</span><span class="n">infer_split</span><span class="p">)</span><span class="w">
</span><span class="n">infer.data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.io.arrayiter</span><span class="p">(</span><span class="n">data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">matrix</span><span class="p">(</span><span class="n">infer_split</span><span class="p">),</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">matrix</span><span class="p">(</span><span class="n">infer_split</span><span class="p">),</span><span class="w">
</span><span class="n">batch.size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">shuffle</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">FALSE</span><span class="p">)</span><span class="w">
</span><span class="n">infer</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.infer.rnn.one</span><span class="p">(</span><span class="n">infer.data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer.data</span><span class="p">,</span><span class="w">
</span><span class="n">symbol</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">symbol</span><span class="p">,</span><span class="w">
</span><span class="n">arg.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">arg.params</span><span class="p">,</span><span class="w">
</span><span class="n">aux.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">aux.params</span><span class="p">,</span><span class="w">
</span><span class="n">input.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">NULL</span><span class="p">,</span><span class="w">
</span><span class="n">ctx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">devices</span><span class="p">)</span><span class="w">
</span><span class="n">pred_prob</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">as.numeric</span><span class="p">(</span><span class="n">as.array</span><span class="p">(</span><span class="n">mx.nd.slice.axis</span><span class="p">(</span><span class="w">
</span><span class="n">infer</span><span class="o">$</span><span class="n">loss_output</span><span class="p">,</span><span class="w"> </span><span class="n">axis</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0</span><span class="p">,</span><span class="w"> </span><span class="n">begin</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer_length</span><span class="m">-1</span><span class="p">,</span><span class="w"> </span><span class="n">end</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer_length</span><span class="p">)))</span><span class="w">
</span><span class="n">pred</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">sample</span><span class="p">(</span><span class="nf">length</span><span class="p">(</span><span class="n">pred_prob</span><span class="p">),</span><span class="w"> </span><span class="n">prob</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">pred_prob</span><span class="p">,</span><span class="w"> </span><span class="n">size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">)</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="m">1</span><span class="w">
</span><span class="n">predict</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">pred</span><span class="w">
</span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="n">i</span><span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="m">200</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
</span><span class="n">infer.data</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.io.arrayiter</span><span class="p">(</span><span class="n">data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">as.matrix</span><span class="p">(</span><span class="n">pred</span><span class="p">),</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">as.matrix</span><span class="p">(</span><span class="n">pred</span><span class="p">),</span><span class="w">
</span><span class="n">batch.size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">shuffle</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">FALSE</span><span class="p">)</span><span class="w">
</span><span class="n">infer</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">mx.infer.rnn.one</span><span class="p">(</span><span class="n">infer.data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer.data</span><span class="p">,</span><span class="w">
</span><span class="n">symbol</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">symbol</span><span class="p">,</span><span class="w">
</span><span class="n">arg.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">arg.params</span><span class="p">,</span><span class="w">
</span><span class="n">aux.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">aux.params</span><span class="p">,</span><span class="w">
</span><span class="n">input.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">rnn.state</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer</span><span class="p">[[</span><span class="m">2</span><span class="p">]],</span><span class="w">
</span><span class="n">rnn.state.cell</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer</span><span class="p">[[</span><span class="m">3</span><span class="p">]]),</span><span class="w">
</span><span class="n">ctx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">devices</span><span class="p">)</span><span class="w">
</span><span class="n">pred_prob</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">as.numeric</span><span class="p">(</span><span class="n">as.array</span><span class="p">(</span><span class="n">infer</span><span class="o">$</span><span class="n">loss_output</span><span class="p">))</span><span class="w">
</span><span class="n">pred</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">sample</span><span class="p">(</span><span class="nf">length</span><span class="p">(</span><span class="n">pred_prob</span><span class="p">),</span><span class="w"> </span><span class="n">prob</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">pred_prob</span><span class="p">,</span><span class="w"> </span><span class="n">size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">replace</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nb">T</span><span class="p">)</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="m">1</span><span class="w">
</span><span class="n">predict</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span><span class="w"> </span><span class="n">pred</span><span class="p">)</span><span class="w">
</span><span class="p">}</span><span class="w">
</span><span class="n">predict_txt</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">paste0</span><span class="p">(</span><span class="n">rev_dic</span><span class="p">[</span><span class="nf">as.character</span><span class="p">(</span><span class="n">predict</span><span class="p">)],</span><span class="w"> </span><span class="n">collapse</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">""</span><span class="p">)</span><span class="w">
</span><span class="n">predict_txt_tot</span><span class="w"> </span><span class="o">&lt;-</span><span class="w"> </span><span class="n">paste0</span><span class="p">(</span><span class="n">infer_raw</span><span class="p">,</span><span class="w"> </span><span class="n">predict_txt</span><span class="p">,</span><span class="w"> </span><span class="n">collapse</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">""</span><span class="p">)</span><span class="w">
</span><span class="n">print</span><span class="p">(</span><span class="n">predict_txt_tot</span><span class="p">)</span><span class="w">
</span></code></pre></div></div>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[1] "Thou NAknowledge thee my Comfort and his late she.FRIAR LAURENCE:Nothing a groats waterd forth. The lend he thank that;When she I am brother draw London: and not hear that know.BENVOLIO:How along, makes your "
</code></pre></div></div>
<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
</div>
</div>
</div>
</div>
</article>
</main><footer class="site-footer h-card">
<!-- Footer content row: resource links, social-media links, and the project tagline. -->
<div class="wrapper">
  <div class="row">
    <div class="col-4">
      <h4 class="footer-category-title">Resources</h4>
      <ul class="contact-list">
        <li><a href="/versions/master/community#stay-connected">Mailing lists</a></li>
        <li><a href="https://discuss.mxnet.io">MXNet Discuss forum</a></li>
        <li><a href="/versions/master/community#github-issues">GitHub Issues</a></li>
        <li><a href="https://github.com/apache/incubator-mxnet/projects">Projects</a></li>
        <li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
        <li><a href="/versions/master/community">Contribute To MXNet</a></li>
      </ul>
    </div>
    <div class="col-4">
      <!-- The icons are the only thing that tells the Twitter and YouTube links apart
           (both show the text "apachemxnet"), so each icon is exposed as a named image
           and becomes part of its link's accessible name. -->
      <ul class="social-media-list">
        <li><a href="https://github.com/apache/incubator-mxnet"><svg class="svg-icon" role="img" aria-label="GitHub"><use xlink:href="/versions/master/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/incubator-mxnet</span></a></li>
        <li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon" role="img" aria-label="Twitter"><use xlink:href="/versions/master/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li>
        <li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon" role="img" aria-label="YouTube"><use xlink:href="/versions/master/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li>
      </ul>
    </div>
    <div class="col-4 footer-text">
      <p>A flexible and efficient library for deep learning.</p>
    </div>
  </div>
</div>
</footer>
<!-- Secondary footer: Apache Incubator logo, incubation disclaimer, and copyright/trademark notice. -->
<footer class="site-footer2">
  <div class="wrapper">
    <div class="row">
      <div class="col-3">
        <!-- alt text names the logo so the image is not announced by its file name. -->
        <img src="/versions/master/assets/img/apache_incubator_logo.png" class="footer-logo col-2" alt="Apache Incubator">
      </div>
      <div class="footer-bottom-warning col-9">
        <p>Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF),
          <span style="font-weight:bold">sponsored by the <i>Apache Incubator</i></span>. Incubation is required
          of all newly accepted projects until a further review indicates that the infrastructure,
          communications, and decision making process have stabilized in a manner consistent with other
          successful ASF projects. While incubation status is not necessarily a reflection of the completeness
          or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
        </p>
        <p>"Copyright © 2017-2018, The Apache Software Foundation. Apache MXNet, MXNet, Apache, the Apache
          feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the
          Apache Software Foundation."</p>
      </div>
    </div>
  </div>
</footer>
</body>
</html>