| <!DOCTYPE html> |
| <html lang="en"><head> |
| <meta charset="utf-8"> |
| <meta http-equiv="X-UA-Compatible" content="IE=edge"> |
| <meta name="viewport" content="width=device-width, initial-scale=1"> |
| <link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 --> |
| <title>Char RNN Model | Apache MXNet</title> |
| <meta name="generator" content="Jekyll v4.0.0" /> |
| <meta property="og:title" content="Char RNN Model" /> |
| <meta property="og:locale" content="en_US" /> |
| <meta name="description" content="A flexible and efficient library for deep learning." /> |
| <meta property="og:description" content="A flexible and efficient library for deep learning." /> |
| <link rel="canonical" href="https://mxnet.apache.org/versions/master/api/r/docs/tutorials/char_rnn_model" /> |
| <meta property="og:url" content="https://mxnet.apache.org/versions/master/api/r/docs/tutorials/char_rnn_model" /> |
| <meta property="og:site_name" content="Apache MXNet" /> |
| <script type="application/ld+json"> |
| {"url":"https://mxnet.apache.org/versions/master/api/r/docs/tutorials/char_rnn_model","headline":"Char RNN Model","description":"A flexible and efficient library for deep learning.","@type":"WebPage","@context":"https://schema.org"}</script> |
| <!-- End Jekyll SEO tag --> |
| <script src="https://medium-widget.pixelpoint.io/widget.js"></script> |
| <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.css" /> |
| <link rel="stylesheet" href="/versions/master/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/master/feed.xml" title="Apache MXNet" /><script> |
| if(!(window.doNotTrack === "1" || navigator.doNotTrack === "1" || navigator.doNotTrack === "yes" || navigator.msDoNotTrack === "1")) { |
| (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ |
| (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), |
| m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) |
| })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); |
| |
| ga('create', 'UA-96378503-1', 'auto'); |
| ga('send', 'pageview'); |
| } |
| </script> |
| |
| <script src="/versions/master/assets/js/jquery-3.3.1.min.js"></script><script src="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js" defer></script> |
| <script src="/versions/master/assets/js/globalSearch.js" defer></script> |
| <script src="/versions/master/assets/js/clipboard.js" defer></script> |
| <script src="/versions/master/assets/js/copycode.js" defer></script></head> |
| <body><header class="site-header" role="banner"> |
| |
| <script> |
| $(document).ready(function () { |
| |
| // HEADER OPACITY LOGIC |
| |
| function opacity_header() { |
| var value = "rgba(4,140,204," + ($(window).scrollTop() / 300 + 0.4) + ")" |
| $('.site-header').css("background-color", value) |
| } |
| |
| $(window).scroll(function () { |
| opacity_header() |
| }) |
| opacity_header(); |
| |
| // MENU SELECTOR LOGIC |
| $('.page-link').each( function () { |
| if (window.location.href.includes(this.href)) { |
| $(this).addClass("page-current"); |
| } |
| }); |
| }) |
| </script> |
| <div class="wrapper"> |
| <a class="site-title" rel="author" href="/versions/master/"><img |
| src="/versions/master/assets/img/mxnet_logo.png" class="site-header-logo"></a> |
| <nav class="site-nav"> |
| <input type="checkbox" id="nav-trigger" class="nav-trigger"/> |
| <label for="nav-trigger"> |
| <span class="menu-icon"> |
| <svg viewBox="0 0 18 15" width="18px" height="15px"> |
| <path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/> |
| </svg> |
| </span> |
| </label> |
| <div class="gs-search-border"> |
| <div id="gs-search-icon"></div> |
| <form id="global-search-form"> |
| <input id="global-search" type="text" title="Search" placeholder="Search" /> |
| <div id="global-search-dropdown-container"> |
| <button class="gs-current-version btn" type="button" data-toggle="dropdown"> |
| <span id="gs-current-version-label">master</span> |
| <svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"> |
| <path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path> |
| </svg> |
| </button> |
| <ul class="gs-opt-group gs-version-dropdown"> |
| |
| |
| <li class="gs-opt gs-versions active">master</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.7.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.6.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.5.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.4.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.3.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.2.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.1.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.0.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.12.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.11.0</li> |
| |
| |
| </ul> |
| </div> |
| <span id="global-search-close">x</span> |
| </form> |
| </div> |
| <div class="trigger"> |
| <div id="global-search-mobile-border"> |
| <div id="gs-search-icon-mobile"></div> |
| <input id="global-search-mobile" placeholder="Search..." type="text"/> |
| <div id="global-search-dropdown-container-mobile"> |
| <button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown"> |
| <svg class="gs-dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"> |
| <path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path> |
| </svg> |
| </button> |
| <ul class="gs-opt-group gs-version-dropdown-mobile"> |
| |
| |
| <li class="gs-opt gs-versions active">master</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.7.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.6.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.5.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.4.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.3.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.2.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.1.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">1.0.0</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.12.1</li> |
| |
| |
| |
| <li class="gs-opt gs-versions">0.11.0</li> |
| |
| |
| </ul> |
| </div> |
| </div> |
| <a class="page-link" href="/versions/master/get_started">Get Started</a> |
| <a class="page-link" href="/versions/master/blog">Blog</a> |
| <a class="page-link" href="/versions/master/features">Features</a> |
| <a class="page-link" href="/versions/master/ecosystem">Ecosystem</a> |
| <a class="page-link" href="/versions/master/api">Docs & Tutorials</a> |
| <a class="page-link" href="https://github.com/apache/incubator-mxnet">GitHub</a> |
| <div class="dropdown"> |
| <span class="dropdown-header">master |
| <svg class="dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg> |
| </span> |
| <div class="dropdown-content"> |
| |
| |
| <a class="dropdown-option-active" href="/">master</a> |
| |
| |
| |
| <a href="/versions/1.7.0/">1.7.0</a> |
| |
| |
| |
| <a href="/versions/1.6.0/">1.6.0</a> |
| |
| |
| |
| <a href="/versions/1.5.0/">1.5.0</a> |
| |
| |
| |
| <a href="/versions/1.4.1/">1.4.1</a> |
| |
| |
| |
| <a href="/versions/1.3.1/">1.3.1</a> |
| |
| |
| |
| <a href="/versions/1.2.1/">1.2.1</a> |
| |
| |
| |
| <a href="/versions/1.1.0/">1.1.0</a> |
| |
| |
| |
| <a href="/versions/1.0.0/">1.0.0</a> |
| |
| |
| |
| <a href="/versions/0.12.1/">0.12.1</a> |
| |
| |
| |
| <a href="/versions/0.11.0/">0.11.0</a> |
| |
| |
| </div> |
| </div> |
| </div> |
| </nav> |
| </div> |
| </header> |
| <main class="page-content" aria-label="Content"> |
| <article class="post"> |
| |
| <header class="post-header wrapper"> |
| <h1 class="post-title">Char RNN Model</h1> |
| </header> |
| |
| <div class="post-content"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-3 docs-side-bar"> |
| |
| |
| </ul> |
| </div> |
| <div class="col-9"> |
| <!--- Licensed to the Apache Software Foundation (ASF) under one --> |
| <!--- or more contributor license agreements. See the NOTICE file --> |
| <!--- distributed with this work for additional information --> |
| <!--- regarding copyright ownership. The ASF licenses this file --> |
| <!--- to you under the Apache License, Version 2.0 (the --> |
| <!--- "License"); you may not use this file except in compliance --> |
| <!--- with the License. You may obtain a copy of the License at --> |
| |
| <!--- http://www.apache.org/licenses/LICENSE-2.0 --> |
| |
| <!--- Unless required by applicable law or agreed to in writing, --> |
| <!--- software distributed under the License is distributed on an --> |
| <!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --> |
| <!--- KIND, either express or implied. See the License for the --> |
| <!--- specific language governing permissions and limitations --> |
| <!--- under the License. --> |
| |
| <h1 id="character-level-language-model-using-rnn">Character-level Language Model using RNN</h1> |
| |
| <p>This tutorial demonstrates how to build a character-level language model with a recurrent neural network (RNN) using the MXNet R package. You will need the following R packages to run this tutorial:</p> |
| <ul> |
| <li>readr</li> |
| <li>stringr</li> |
| <li>stringi</li> |
| <li>mxnet</li> |
| </ul> |
| |
| <p>We will use the <a href="https://github.com/dmlc/web-data/tree/master/mxnet/tinyshakespeare">tinyshakespeare</a> dataset to build this model.</p> |
| |
| <div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">library</span><span class="p">(</span><span class="s2">"readr"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">library</span><span class="p">(</span><span class="s2">"stringr"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">library</span><span class="p">(</span><span class="s2">"stringi"</span><span class="p">)</span><span class="w"> |
| </span><span class="n">library</span><span class="p">(</span><span class="s2">"mxnet"</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div></div> |
| |
| <h2 id="preprocess-and-prepare-the-data">Preprocess and prepare the data</h2> |
| |
| <p>Download the data:</p> |
| |
| <div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">download.data</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"> |
| </span><span class="n">dir.create</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span><span class="w"> </span><span class="n">showWarnings</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">FALSE</span><span class="p">)</span><span class="w"> |
| </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="o">!</span><span class="n">file.exists</span><span class="p">(</span><span class="n">file.path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span><span class="w"> </span><span class="s1">'input.txt'</span><span class="p">)))</span><span class="w"> </span><span class="p">{</span><span class="w"> |
| </span><span class="n">download.file</span><span class="p">(</span><span class="n">url</span><span class="o">=</span><span class="s1">'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt'</span><span class="p">,</span><span class="w"> |
| </span><span class="n">destfile</span><span class="o">=</span><span class="n">file.path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span><span class="w"> </span><span class="s1">'input.txt'</span><span class="p">),</span><span class="w"> </span><span class="n">method</span><span class="o">=</span><span class="s1">'wget'</span><span class="p">)</span><span class="w"> |
| </span><span class="p">}</span><span class="w"> |
| </span><span class="p">}</span><span class="w"> |
| </span></code></pre></div></div> |
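<p>A note on path handling: joining a directory and a file name by plain string concatenation depends on whether the directory ends in a slash, while <code class="highlighter-rouge">file.path</code> always inserts a separator. The following is a minimal base R sketch; the directory name <code class="highlighter-rouge">data</code> is a hypothetical value chosen for illustration and is not part of the tutorial.</p>

<div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Hypothetical directory name, used only for illustration
data_dir &lt;- "data"

# Plain concatenation adds no separator between directory and file name
concat_path &lt;- paste0(data_dir, "input.txt")

# file.path() joins its components with "/"
joined_path &lt;- file.path(data_dir, "input.txt")

print(concat_path)  # "datainput.txt"
print(joined_path)  # "data/input.txt"
</code></pre></div></div>

<p>With <code class="highlighter-rouge">data_dir = "./"</code> both forms resolve to a file in the working directory, which is where the later call <code class="highlighter-rouge">make_data(path = "input.txt", ...)</code> looks for it.</p>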
| |
| <p>Next, we transform the text into feature vectors that are fed into the RNN model. The <code class="highlighter-rouge">make_data</code> function reads the dataset, converts it to ASCII, strips any non-printable characters, splits it into individual characters, and groups them into sequences of length <code class="highlighter-rouge">seq.len</code>. The labels are the same characters shifted by one position, so the model learns to predict the next character.</p> |
| |
| <div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">make_data</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">path</span><span class="p">,</span><span class="w"> </span><span class="n">seq.len</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">32</span><span class="p">,</span><span class="w"> </span><span class="n">dic</span><span class="o">=</span><span class="kc">NULL</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"> |
| |
| </span><span class="n">text_vec</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">read_file</span><span class="p">(</span><span class="n">file</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">path</span><span class="p">)</span><span class="w"> |
| </span><span class="n">text_vec</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">stri_enc_toascii</span><span class="p">(</span><span class="n">str</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">text_vec</span><span class="p">)</span><span class="w"> |
| </span><span class="n">text_vec</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">str_replace_all</span><span class="p">(</span><span class="n">string</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">text_vec</span><span class="p">,</span><span class="w"> </span><span class="n">pattern</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"[^[:print:]]"</span><span class="p">,</span><span class="w"> </span><span class="n">replacement</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">""</span><span class="p">)</span><span class="w"> |
| </span><span class="n">text_vec</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">strsplit</span><span class="p">(</span><span class="n">text_vec</span><span class="p">,</span><span class="w"> </span><span class="s1">''</span><span class="p">)</span><span class="w"> </span><span class="o">%>%</span><span class="w"> </span><span class="n">unlist</span><span class="w"> |
| |
| </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="nf">is.null</span><span class="p">(</span><span class="n">dic</span><span class="p">))</span><span class="w"> </span><span class="p">{</span><span class="w"> |
| </span><span class="n">char_keep</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">sort</span><span class="p">(</span><span class="n">unique</span><span class="p">(</span><span class="n">text_vec</span><span class="p">))</span><span class="w"> |
| </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="n">char_keep</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">names</span><span class="p">(</span><span class="n">dic</span><span class="p">)[</span><span class="o">!</span><span class="n">dic</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="m">0</span><span class="p">]</span><span class="w"> |
| |
| </span><span class="c1"># Remove terms not part of dictionary</span><span class="w"> |
| </span><span class="n">text_vec</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">text_vec</span><span class="p">[</span><span class="n">text_vec</span><span class="w"> </span><span class="o">%in%</span><span class="w"> </span><span class="n">char_keep</span><span class="p">]</span><span class="w"> |
| |
| </span><span class="c1"># Build dictionary</span><span class="w"> |
| </span><span class="n">dic</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="nf">length</span><span class="p">(</span><span class="n">char_keep</span><span class="p">)</span><span class="w"> |
| </span><span class="nf">names</span><span class="p">(</span><span class="n">dic</span><span class="p">)</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">char_keep</span><span class="w"> |
| |
| </span><span class="c1"># reverse dictionary</span><span class="w"> |
| </span><span class="n">rev_dic</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">names</span><span class="p">(</span><span class="n">dic</span><span class="p">)</span><span class="w"> |
| </span><span class="nf">names</span><span class="p">(</span><span class="n">rev_dic</span><span class="p">)</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">dic</span><span class="w"> |
| |
| </span><span class="c1"># Adjust by -1 to have a 1-lag for labels</span><span class="w"> |
| </span><span class="n">num.seq</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="p">(</span><span class="nf">length</span><span class="p">(</span><span class="n">text_vec</span><span class="p">)</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="m">1</span><span class="p">)</span><span class="w"> </span><span class="o">%/%</span><span class="w"> </span><span class="n">seq.len</span><span class="w"> |
| |
| </span><span class="n">features</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">dic</span><span class="p">[</span><span class="n">text_vec</span><span class="p">[</span><span class="m">1</span><span class="o">:</span><span class="p">(</span><span class="n">seq.len</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">num.seq</span><span class="p">)]]</span><span class="w"> |
| </span><span class="n">labels</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">dic</span><span class="p">[</span><span class="n">text_vec</span><span class="p">[</span><span class="m">1</span><span class="o">:</span><span class="p">(</span><span class="n">seq.len</span><span class="o">*</span><span class="n">num.seq</span><span class="p">)</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="m">1</span><span class="p">]]</span><span class="w"> |
| |
| </span><span class="n">features_array</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">features</span><span class="p">,</span><span class="w"> </span><span class="n">dim</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">seq.len</span><span class="p">,</span><span class="w"> </span><span class="n">num.seq</span><span class="p">))</span><span class="w"> |
| </span><span class="n">labels_array</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">array</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span><span class="w"> </span><span class="n">dim</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">seq.len</span><span class="p">,</span><span class="w"> </span><span class="n">num.seq</span><span class="p">))</span><span class="w"> |
| |
| </span><span class="n">return</span><span class="w"> </span><span class="p">(</span><span class="nf">list</span><span class="p">(</span><span class="n">features_array</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">features_array</span><span class="p">,</span><span class="w"> </span><span class="n">labels_array</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">labels_array</span><span class="p">,</span><span class="w"> </span><span class="n">dic</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">dic</span><span class="p">,</span><span class="w"> </span><span class="n">rev_dic</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">rev_dic</span><span class="p">))</span><span class="w"> |
| </span><span class="p">}</span><span class="w"> |
| |
| |
| </span><span class="n">seq.len</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="m">100</span><span class="w"> |
| </span><span class="n">data_prep</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">make_data</span><span class="p">(</span><span class="n">path</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"input.txt"</span><span class="p">,</span><span class="w"> </span><span class="n">seq.len</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">seq.len</span><span class="p">,</span><span class="w"> </span><span class="n">dic</span><span class="o">=</span><span class="kc">NULL</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div></div> |
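<p>To make the encoding easier to follow, here is a small sketch of the same dictionary and 1-lag label construction on a toy string. It uses only base R; the string <code class="highlighter-rouge">"hello world"</code>, the short <code class="highlighter-rouge">seq.len</code> of 5, and the helper variable <code class="highlighter-rouge">idx</code> are assumptions made for illustration, not values from the tutorial.</p>

<div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Toy text standing in for the corpus (an assumption for illustration)
text_vec &lt;- unlist(strsplit("hello world", ""))

# Dictionary: each distinct character gets an integer id starting at 1
char_keep &lt;- sort(unique(text_vec))
dic &lt;- seq_along(char_keep)
names(dic) &lt;- char_keep

# Short sequences so the result is easy to inspect
seq.len &lt;- 5
num.seq &lt;- (length(text_vec) - 1) %/% seq.len

# Features are characters 1..n, labels are the same characters shifted by one
idx &lt;- 1:(seq.len * num.seq)
features &lt;- array(dic[text_vec[idx]], dim = c(seq.len, num.seq))
labels &lt;- array(dic[text_vec[idx + 1]], dim = c(seq.len, num.seq))

# One column per sequence: seq.len rows, num.seq columns
print(dim(features))

# The label at position i is the feature at position i + 1
print(all(features[-1] == labels[-length(labels)]))
</code></pre></div></div>

<p>Each column of <code class="highlighter-rouge">features</code> holds one sequence, and the matching column of <code class="highlighter-rouge">labels</code> holds the characters the model should predict at each step.</p>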
| |
| <p>Fetch the features and labels for training the model, and split the data into training and evaluation sets in a 9:1 ratio.</p> |
| |
| <div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">X</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">data_prep</span><span class="o">$</span><span class="n">features_array</span><span class="w"> |
| </span><span class="n">Y</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">data_prep</span><span class="o">$</span><span class="n">labels_array</span><span class="w"> |
| </span><span class="n">dic</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">data_prep</span><span class="o">$</span><span class="n">dic</span><span class="w"> |
| </span><span class="n">rev_dic</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">data_prep</span><span class="o">$</span><span class="n">rev_dic</span><span class="w"> |
| </span><span class="n">vocab</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">length</span><span class="p">(</span><span class="n">dic</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">samples</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">tail</span><span class="p">(</span><span class="nf">dim</span><span class="p">(</span><span class="n">X</span><span class="p">),</span><span class="w"> </span><span class="m">1</span><span class="p">)</span><span class="w"> |
| </span><span class="n">train.val.fraction</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="m">0.9</span><span class="w"> |
| |
| </span><span class="n">X.train.data</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">X</span><span class="p">[,</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="nf">as.integer</span><span class="p">(</span><span class="n">samples</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">train.val.fraction</span><span class="p">)]</span><span class="w"> |
| </span><span class="n">X.val.data</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">X</span><span class="p">[,</span><span class="w"> </span><span class="o">-</span><span class="p">(</span><span class="m">1</span><span class="o">:</span><span class="nf">as.integer</span><span class="p">(</span><span class="n">samples</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">train.val.fraction</span><span class="p">))]</span><span class="w"> |
| |
| </span><span class="n">X.train.label</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">Y</span><span class="p">[,</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="nf">as.integer</span><span class="p">(</span><span class="n">samples</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">train.val.fraction</span><span class="p">)]</span><span class="w"> |
| </span><span class="n">X.val.label</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">Y</span><span class="p">[,</span><span class="w"> </span><span class="o">-</span><span class="p">(</span><span class="m">1</span><span class="o">:</span><span class="nf">as.integer</span><span class="p">(</span><span class="n">samples</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">train.val.fraction</span><span class="p">))]</span><span class="w"> |
| |
| </span><span class="n">train_buckets</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="s2">"100"</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">X.train.data</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">X.train.label</span><span class="p">))</span><span class="w"> |
| </span><span class="n">eval_buckets</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="s2">"100"</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">X.val.data</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">X.val.label</span><span class="p">))</span><span class="w"> |
| |
| </span><span class="n">train_buckets</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">buckets</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">train_buckets</span><span class="p">,</span><span class="w"> </span><span class="n">dic</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">dic</span><span class="p">,</span><span class="w"> </span><span class="n">rev_dic</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">rev_dic</span><span class="p">)</span><span class="w"> |
| </span><span class="n">eval_buckets</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">buckets</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">eval_buckets</span><span class="p">,</span><span class="w"> </span><span class="n">dic</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">dic</span><span class="p">,</span><span class="w"> </span><span class="n">rev_dic</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">rev_dic</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div></div> |
| |
| <p>Create iterators for the training and evaluation datasets.</p> |
| |
| <div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">vocab</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">length</span><span class="p">(</span><span class="n">eval_buckets</span><span class="o">$</span><span class="n">dic</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">batch.size</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="m">32</span><span class="w"> |
| |
| </span><span class="n">train.data</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.io.bucket.iter</span><span class="p">(</span><span class="n">buckets</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">train_buckets</span><span class="o">$</span><span class="n">buckets</span><span class="p">,</span><span class="w"> </span><span class="n">batch.size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">batch.size</span><span class="p">,</span><span class="w"> |
| </span><span class="n">data.mask.element</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0</span><span class="p">,</span><span class="w"> </span><span class="n">shuffle</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">TRUE</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">eval.data</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.io.bucket.iter</span><span class="p">(</span><span class="n">buckets</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">eval_buckets</span><span class="o">$</span><span class="n">buckets</span><span class="p">,</span><span class="w"> </span><span class="n">batch.size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">batch.size</span><span class="p">,</span><span class="w"> |
| </span><span class="n">data.mask.element</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0</span><span class="p">,</span><span class="w"> </span><span class="n">shuffle</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">FALSE</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div></div> |
| |
| <h2 id="train-the-model">Train the Model</h2> |
| |
| <p>This model is a multi-layer RNN for sampling from character-level language models. It uses a one-to-one configuration because, for each character, we want to predict the next one. A sequence of length 100 therefore has 100 labels, corresponding to the same sequence of characters offset by one position. The parameter output_last_state is set to TRUE so that the state of the RNN cells can be accessed when performing inference.</p> |
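<p>To make the one-position offset concrete, here is a minimal base R sketch on a toy string. The string and the names <code>toy_dic</code>, <code>toy_data</code> and <code>toy_label</code> are hypothetical and are not part of the tutorial's data pipeline; ids start at 1 on the assumption that 0 is reserved for padding, mirroring the mask element used by the iterators.</p>

```r
# Toy example (assumption: not the tutorial's data): encode a string and shift by one.
chars <- strsplit("hello", "")[[1]]

# Map each distinct character to an integer id, starting at 1 (0 is left for padding).
toy_dic <- seq_along(unique(chars))
names(toy_dic) <- unique(chars)

ids <- unname(toy_dic[chars])      # 1 2 3 3 4

# Inputs are all characters but the last; labels are the same sequence one step ahead.
toy_data  <- ids[-length(ids)]     # 1 2 3 3
toy_label <- ids[-1]               # 2 3 3 4
```

<p>At every position the label is simply the character that follows the input character, which is what the one-to-one configuration trains the network to predict.</p>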
| |
| <div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">rnn_graph_one_one</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">rnn.graph</span><span class="p">(</span><span class="n">num_rnn_layer</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">3</span><span class="p">,</span><span class="w"> |
| </span><span class="n">num_hidden</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">96</span><span class="p">,</span><span class="w"> |
| </span><span class="n">input_size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">vocab</span><span class="p">,</span><span class="w"> |
| </span><span class="n">num_embed</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">64</span><span class="p">,</span><span class="w"> |
| </span><span class="n">num_decode</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">vocab</span><span class="p">,</span><span class="w"> |
| </span><span class="n">dropout</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0.2</span><span class="p">,</span><span class="w"> |
| </span><span class="n">ignore_label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0</span><span class="p">,</span><span class="w"> |
| </span><span class="n">cell_type</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"lstm"</span><span class="p">,</span><span class="w"> |
| </span><span class="n">masking</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">FALSE</span><span class="p">,</span><span class="w"> |
| </span><span class="n">output_last_state</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">TRUE</span><span class="p">,</span><span class="w"> |
| </span><span class="n">loss_output</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"softmax"</span><span class="p">,</span><span class="w"> |
| </span><span class="n">config</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"one-to-one"</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">graph.viz</span><span class="p">(</span><span class="n">rnn_graph_one_one</span><span class="p">,</span><span class="w"> </span><span class="n">type</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"graph"</span><span class="p">,</span><span class="w"> </span><span class="n">direction</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"LR"</span><span class="p">,</span><span class="w"> |
| </span><span class="n">graph.height.px</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">180</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="o">=</span><span class="nf">c</span><span class="p">(</span><span class="m">100</span><span class="p">,</span><span class="w"> </span><span class="m">64</span><span class="p">))</span><span class="w"> |
| |
| </span><span class="n">devices</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.cpu</span><span class="p">()</span><span class="w"> |
| |
| </span><span class="n">initializer</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.init.Xavier</span><span class="p">(</span><span class="n">rnd_type</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"gaussian"</span><span class="p">,</span><span class="w"> </span><span class="n">factor_type</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"avg"</span><span class="p">,</span><span class="w"> </span><span class="n">magnitude</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">3</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">optimizer</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.opt.create</span><span class="p">(</span><span class="s2">"adadelta"</span><span class="p">,</span><span class="w"> </span><span class="n">rho</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0.9</span><span class="p">,</span><span class="w"> </span><span class="n">eps</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1e-5</span><span class="p">,</span><span class="w"> </span><span class="n">wd</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1e-8</span><span class="p">,</span><span class="w"> |
| </span><span class="n">clip_gradient</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">5</span><span class="p">,</span><span class="w"> </span><span class="n">rescale.grad</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="o">/</span><span class="n">batch.size</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">logger</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.metric.logger</span><span class="p">()</span><span class="w"> |
| </span><span class="n">epoch.end.callback</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.callback.log.train.metric</span><span class="p">(</span><span class="n">period</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">logger</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">logger</span><span class="p">)</span><span class="w"> |
| </span><span class="n">batch.end.callback</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.callback.log.train.metric</span><span class="p">(</span><span class="n">period</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">50</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">mx.metric.custom_nd</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">name</span><span class="p">,</span><span class="w"> </span><span class="n">feval</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"> |
| </span><span class="n">init</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="k">function</span><span class="p">()</span><span class="w"> </span><span class="p">{</span><span class="w"> |
| </span><span class="nf">c</span><span class="p">(</span><span class="m">0</span><span class="p">,</span><span class="w"> </span><span class="m">0</span><span class="p">)</span><span class="w"> |
| </span><span class="p">}</span><span class="w"> |
| </span><span class="n">update</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">label</span><span class="p">,</span><span class="w"> </span><span class="n">pred</span><span class="p">,</span><span class="w"> </span><span class="n">state</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"> |
| </span><span class="n">m</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">feval</span><span class="p">(</span><span class="n">label</span><span class="p">,</span><span class="w"> </span><span class="n">pred</span><span class="p">)</span><span class="w"> |
| </span><span class="n">state</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">state</span><span class="p">[[</span><span class="m">1</span><span class="p">]]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">state</span><span class="p">[[</span><span class="m">2</span><span class="p">]]</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">m</span><span class="p">)</span><span class="w"> |
| </span><span class="nf">return</span><span class="p">(</span><span class="n">state</span><span class="p">)</span><span class="w"> |
| </span><span class="p">}</span><span class="w"> |
| </span><span class="n">get</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">state</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"> |
| </span><span class="nf">list</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span><span class="w"> </span><span class="n">value</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="p">(</span><span class="n">state</span><span class="p">[[</span><span class="m">2</span><span class="p">]]</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="n">state</span><span class="p">[[</span><span class="m">1</span><span class="p">]]))</span><span class="w"> |
| </span><span class="p">}</span><span class="w"> |
| </span><span class="n">ret</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="p">(</span><span class="nf">list</span><span class="p">(</span><span class="n">init</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">init</span><span class="p">,</span><span class="w"> </span><span class="n">update</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">update</span><span class="p">,</span><span class="w"> </span><span class="n">get</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">get</span><span class="p">))</span><span class="w"> |
| </span><span class="nf">class</span><span class="p">(</span><span class="n">ret</span><span class="p">)</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="s2">"mx.metric"</span><span class="w"> |
| </span><span class="nf">return</span><span class="p">(</span><span class="n">ret</span><span class="p">)</span><span class="w"> |
| </span><span class="p">}</span><span class="w"> |
| |
| </span><span class="n">mx.metric.Perplexity</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.metric.custom_nd</span><span class="p">(</span><span class="s2">"Perplexity"</span><span class="p">,</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">label</span><span class="p">,</span><span class="w"> </span><span class="n">pred</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"> |
| </span><span class="n">label</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.nd.reshape</span><span class="p">(</span><span class="n">label</span><span class="p">,</span><span class="w"> </span><span class="n">shape</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">-1</span><span class="p">)</span><span class="w"> |
| </span><span class="n">label_probs</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">as.array</span><span class="p">(</span><span class="n">mx.nd.choose.element.0index</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span><span class="w"> </span><span class="n">label</span><span class="p">))</span><span class="w"> |
| </span><span class="n">batch</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">length</span><span class="p">(</span><span class="n">label_probs</span><span class="p">)</span><span class="w"> |
| </span><span class="n">NLL</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="o">-</span><span class="nf">sum</span><span class="p">(</span><span class="nf">log</span><span class="p">(</span><span class="n">pmax</span><span class="p">(</span><span class="m">1e-15</span><span class="p">,</span><span class="w"> </span><span class="n">as.array</span><span class="p">(</span><span class="n">label_probs</span><span class="p">))))</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="n">batch</span><span class="w"> |
| </span><span class="n">Perplexity</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">exp</span><span class="p">(</span><span class="n">NLL</span><span class="p">)</span><span class="w"> |
| </span><span class="nf">return</span><span class="p">(</span><span class="n">Perplexity</span><span class="p">)</span><span class="w"> |
| </span><span class="p">})</span><span class="w"> |
| |
| </span><span class="n">model</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.model.buckets</span><span class="p">(</span><span class="n">symbol</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">rnn_graph_one_one</span><span class="p">,</span><span class="w"> |
| </span><span class="n">train.data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">train.data</span><span class="p">,</span><span class="w"> </span><span class="n">eval.data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">eval.data</span><span class="p">,</span><span class="w"> |
| </span><span class="n">num.round</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">20</span><span class="p">,</span><span class="w"> </span><span class="n">ctx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">devices</span><span class="p">,</span><span class="w"> </span><span class="n">verbose</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">TRUE</span><span class="p">,</span><span class="w"> |
| </span><span class="n">metric</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mx.metric.Perplexity</span><span class="p">,</span><span class="w"> |
| </span><span class="n">initializer</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">initializer</span><span class="p">,</span><span class="w"> </span><span class="n">optimizer</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">optimizer</span><span class="p">,</span><span class="w"> |
| </span><span class="n">batch.end.callback</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">NULL</span><span class="p">,</span><span class="w"> |
| </span><span class="n">epoch.end.callback</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">epoch.end.callback</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">mx.model.save</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">prefix</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"one_to_one_seq_model"</span><span class="p">,</span><span class="w"> </span><span class="n">iteration</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">20</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div></div> |
| |
| <div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Start training with 1 devices |
| [1] Train-Perplexity=13.7040474322178 |
| [1] Validation-Perplexity=7.94617194460922 |
| [2] Train-Perplexity=6.57039815554525 |
| [2] Validation-Perplexity=6.60806110658011 |
| [3] Train-Perplexity=5.65360504501481 |
| [3] Validation-Perplexity=6.18932770630876 |
| [4] Train-Perplexity=5.32547285727298 |
| [4] Validation-Perplexity=6.02198756798859 |
| [5] Train-Perplexity=5.14373631472579 |
| [5] Validation-Perplexity=5.8095658243407 |
| [6] Train-Perplexity=5.03077673487379 |
| [6] Validation-Perplexity=5.72582993567431 |
| [7] Train-Perplexity=4.94453383291536 |
| [7] Validation-Perplexity=5.6445258528126 |
| [8] Train-Perplexity=4.88635290100261 |
| [8] Validation-Perplexity=5.6730024536433 |
| [9] Train-Perplexity=4.84205646230548 |
| [9] Validation-Perplexity=5.50960780230982 |
| [10] Train-Perplexity=4.80441673535513 |
| [10] Validation-Perplexity=5.57002263750006 |
| [11] Train-Perplexity=4.77763413242626 |
| [11] Validation-Perplexity=5.55152143269169 |
| [12] Train-Perplexity=4.74937775290777 |
| [12] Validation-Perplexity=5.44968305351486 |
| [13] Train-Perplexity=4.72824849541467 |
| [13] Validation-Perplexity=5.50889348298234 |
| [14] Train-Perplexity=4.70980846981694 |
| [14] Validation-Perplexity=5.51473225859859 |
| [15] Train-Perplexity=4.69685776886122 |
| [15] Validation-Perplexity=5.45391985233811 |
| [16] Train-Perplexity=4.67837107034824 |
| [16] Validation-Perplexity=5.46636764997829 |
| [17] Train-Perplexity=4.66866961934873 |
| [17] Validation-Perplexity=5.44267086113492 |
| [18] Train-Perplexity=4.65611469144194 |
| [18] Validation-Perplexity=5.4290169469462 |
| [19] Train-Perplexity=4.64614689879405 |
| [19] Validation-Perplexity=5.44221549833917 |
| [20] Train-Perplexity=4.63764001963654 |
| [20] Validation-Perplexity=5.42114250842862 |
| </code></pre></div></div> |
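<p>The perplexity values logged above are the exponential of the mean negative log-likelihood of the true next character. The following base R sketch repeats that computation on made-up numbers; the vector <code>label_probs</code> is hypothetical and is not model output.</p>

```r
# Hypothetical probabilities that a model assigned to the correct next character.
label_probs <- c(0.5, 0.25, 0.25, 0.125)

# Mean negative log-likelihood, clipping at 1e-15 to avoid log(0) as in the metric above.
nll <- -sum(log(pmax(1e-15, label_probs))) / length(label_probs)

# Perplexity is exp(NLL).
perplexity <- exp(nll)

# A uniform guess over k characters has perplexity k (here k = 4).
uniform_perplexity <- exp(-mean(log(rep(1 / 4, 10))))
```

<p>Read this way, a validation perplexity of about 5.4 means the model is roughly as uncertain as a uniform choice among five to six characters.</p>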
| |
| <h2 id="inference-on-the-model">Inference on the Model</h2> |
| |
| <p>We now use the saved model for inference, sampling text one character at a time so that the output resembles the original training data.</p> |
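<p>The sampling step can be sketched in isolation. The example below draws one character id from a hypothetical probability vector and looks up the matching character; <code>pred_prob</code> and <code>toy_rev_dic</code> are made-up stand-ins, and treating class 0 as padding is an assumption of this sketch.</p>

```r
set.seed(42)

# Hypothetical softmax output over 4 classes; class 0 (first entry) stands for padding.
pred_prob <- c(0.0, 0.7, 0.2, 0.1)

# Hypothetical id -> character lookup, indexed by the id as a name.
toy_rev_dic <- c("1" = "a", "2" = "b", "3" = "c")

# sample() returns a 1-based position, so subtract 1 to obtain the 0-based class id.
pred <- sample(length(pred_prob), size = 1, prob = pred_prob) - 1

# Recover the character for the sampled id.
next_char <- toy_rev_dic[[as.character(pred)]]
```

<p>Drawing from the predicted distribution, rather than always taking the most probable character, tends to keep the generated text from repeating the same high-probability sequence.</p>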
| |
| <div class="language-R highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">set.seed</span><span class="p">(</span><span class="m">0</span><span class="p">)</span><span class="w"> |
| </span><span class="n">model</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.model.load</span><span class="p">(</span><span class="n">prefix</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"one_to_one_seq_model"</span><span class="p">,</span><span class="w"> </span><span class="n">iteration</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">20</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">internals</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">symbol</span><span class="o">$</span><span class="n">get.internals</span><span class="p">()</span><span class="w"> |
| </span><span class="n">sym_state</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">internals</span><span class="o">$</span><span class="n">get.output</span><span class="p">(</span><span class="n">which</span><span class="p">(</span><span class="n">internals</span><span class="o">$</span><span class="n">outputs</span><span class="w"> </span><span class="o">%in%</span><span class="w"> </span><span class="s2">"RNN_state"</span><span class="p">))</span><span class="w"> |
| </span><span class="n">sym_state_cell</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">internals</span><span class="o">$</span><span class="n">get.output</span><span class="p">(</span><span class="n">which</span><span class="p">(</span><span class="n">internals</span><span class="o">$</span><span class="n">outputs</span><span class="w"> </span><span class="o">%in%</span><span class="w"> </span><span class="s2">"RNN_state_cell"</span><span class="p">))</span><span class="w"> |
| </span><span class="n">sym_output</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">internals</span><span class="o">$</span><span class="n">get.output</span><span class="p">(</span><span class="n">which</span><span class="p">(</span><span class="n">internals</span><span class="o">$</span><span class="n">outputs</span><span class="w"> </span><span class="o">%in%</span><span class="w"> </span><span class="s2">"loss_output"</span><span class="p">))</span><span class="w"> |
| </span><span class="n">symbol</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.symbol.Group</span><span class="p">(</span><span class="n">sym_output</span><span class="p">,</span><span class="w"> </span><span class="n">sym_state</span><span class="p">,</span><span class="w"> </span><span class="n">sym_state_cell</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">infer_raw</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="s2">"Thou "</span><span class="p">)</span><span class="w"> |
| </span><span class="n">infer_split</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">dic</span><span class="p">[</span><span class="n">strsplit</span><span class="p">(</span><span class="n">infer_raw</span><span class="p">,</span><span class="w"> </span><span class="s1">''</span><span class="p">)</span><span class="w"> </span><span class="o">%>%</span><span class="w"> </span><span class="n">unlist</span><span class="p">]</span><span class="w"> |
| </span><span class="n">infer_length</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">length</span><span class="p">(</span><span class="n">infer_split</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">infer.data</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.io.arrayiter</span><span class="p">(</span><span class="n">data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">matrix</span><span class="p">(</span><span class="n">infer_split</span><span class="p">),</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">matrix</span><span class="p">(</span><span class="n">infer_split</span><span class="p">),</span><span class="w"> |
| </span><span class="n">batch.size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">shuffle</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">FALSE</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">infer</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.infer.rnn.one</span><span class="p">(</span><span class="n">infer.data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer.data</span><span class="p">,</span><span class="w"> |
| </span><span class="n">symbol</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">symbol</span><span class="p">,</span><span class="w"> |
| </span><span class="n">arg.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">arg.params</span><span class="p">,</span><span class="w"> |
| </span><span class="n">aux.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">aux.params</span><span class="p">,</span><span class="w"> |
| </span><span class="n">input.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">NULL</span><span class="p">,</span><span class="w"> |
| </span><span class="n">ctx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">devices</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">pred_prob</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">as.numeric</span><span class="p">(</span><span class="n">as.array</span><span class="p">(</span><span class="n">mx.nd.slice.axis</span><span class="p">(</span><span class="w"> |
| </span><span class="n">infer</span><span class="o">$</span><span class="n">loss_output</span><span class="p">,</span><span class="w"> </span><span class="n">axis</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">0</span><span class="p">,</span><span class="w"> </span><span class="n">begin</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer_length</span><span class="m">-1</span><span class="p">,</span><span class="w"> </span><span class="n">end</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer_length</span><span class="p">)))</span><span class="w"> |
| </span><span class="n">pred</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">sample</span><span class="p">(</span><span class="nf">length</span><span class="p">(</span><span class="n">pred_prob</span><span class="p">),</span><span class="w"> </span><span class="n">prob</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">pred_prob</span><span class="p">,</span><span class="w"> </span><span class="n">size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">)</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="m">1</span><span class="w"> |
| </span><span class="c1"># Start the vector of sampled ids; without this, predict refers to base R's predict() function</span><span class="w"> |
| </span><span class="n">predict</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">pred</span><span class="w"> |
| |
| </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="n">i</span><span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">1</span><span class="o">:</span><span class="m">200</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"> |
| |
| </span><span class="n">infer.data</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.io.arrayiter</span><span class="p">(</span><span class="n">data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">as.matrix</span><span class="p">(</span><span class="n">pred</span><span class="p">),</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">as.matrix</span><span class="p">(</span><span class="n">pred</span><span class="p">),</span><span class="w"> |
| </span><span class="n">batch.size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">shuffle</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">FALSE</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">infer</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">mx.infer.rnn.one</span><span class="p">(</span><span class="n">infer.data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer.data</span><span class="p">,</span><span class="w"> |
| </span><span class="n">symbol</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">symbol</span><span class="p">,</span><span class="w"> |
| </span><span class="n">arg.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">arg.params</span><span class="p">,</span><span class="w"> |
| </span><span class="n">aux.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">model</span><span class="o">$</span><span class="n">aux.params</span><span class="p">,</span><span class="w"> |
| </span><span class="n">input.params</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">rnn.state</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer</span><span class="p">[[</span><span class="m">2</span><span class="p">]],</span><span class="w"> |
| </span><span class="n">rnn.state.cell</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">infer</span><span class="p">[[</span><span class="m">3</span><span class="p">]]),</span><span class="w"> |
| </span><span class="n">ctx</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">devices</span><span class="p">)</span><span class="w"> |
| |
| </span><span class="n">pred_prob</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">as.numeric</span><span class="p">(</span><span class="n">as.array</span><span class="p">(</span><span class="n">infer</span><span class="o">$</span><span class="n">loss_output</span><span class="p">))</span><span class="w"> |
| </span><span class="n">pred</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">sample</span><span class="p">(</span><span class="nf">length</span><span class="p">(</span><span class="n">pred_prob</span><span class="p">),</span><span class="w"> </span><span class="n">prob</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">pred_prob</span><span class="p">,</span><span class="w"> </span><span class="n">size</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">1</span><span class="p">,</span><span class="w"> </span><span class="n">replace</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">TRUE</span><span class="p">)</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="m">1</span><span class="w"> |
| </span><span class="n">predict</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span><span class="w"> </span><span class="n">pred</span><span class="p">)</span><span class="w"> |
| </span><span class="p">}</span><span class="w"> |
| |
| </span><span class="n">predict_txt</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">paste0</span><span class="p">(</span><span class="n">rev_dic</span><span class="p">[</span><span class="nf">as.character</span><span class="p">(</span><span class="n">predict</span><span class="p">)],</span><span class="w"> </span><span class="n">collapse</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">""</span><span class="p">)</span><span class="w"> |
| </span><span class="n">predict_txt_tot</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">paste0</span><span class="p">(</span><span class="n">infer_raw</span><span class="p">,</span><span class="w"> </span><span class="n">predict_txt</span><span class="p">,</span><span class="w"> </span><span class="n">collapse</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">""</span><span class="p">)</span><span class="w"> |
| </span><span class="n">print</span><span class="p">(</span><span class="n">predict_txt_tot</span><span class="p">)</span><span class="w"> |
| </span></code></pre></div></div> |
| |
| <div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[1] "Thou NAknowledge thee my Comfort and his late she.FRIAR LAURENCE:Nothing a groats waterd forth. The lend he thank that;When she I am brother draw London: and not hear that know.BENVOLIO:How along, makes your " |
| </code></pre></div></div> |
| |
| </div> |
| </div> |
| |
| </div> |
| </div> |
| |
| </article> |
| |
| </main><footer class="site-footer h-card"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-4"> |
| <h4 class="footer-category-title">Resources</h4> |
| <ul class="contact-list"> |
| <li><a href="/versions/master/community#stay-connected">Mailing lists</a></li> |
| <li><a href="https://discuss.mxnet.io">MXNet Discuss forum</a></li> |
| <li><a href="/versions/master/community#github-issues">Github Issues</a></li> |
| <li><a href="https://github.com/apache/incubator-mxnet/projects">Projects</a></li> |
| <li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li> |
| <li><a href="/versions/master/community">Contribute To MXNet</a></li> |
| |
| </ul> |
| </div> |
| |
| <div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/incubator-mxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/incubator-mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul> |
| </div> |
| |
| <div class="col-4 footer-text"> |
| <p>A flexible and efficient library for deep learning.</p> |
| </div> |
| </div> |
| </div> |
| </footer> |
| <footer class="site-footer2"> |
| <div class="wrapper"> |
| <div class="row"> |
| <div class="col-3"> |
| <img src="/versions/master/assets/img/apache_incubator_logo.png" class="footer-logo col-2"> |
| </div> |
| <div class="footer-bottom-warning col-9"> |
| <p>Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <span |
| style="font-weight:bold">sponsored by the <i>Apache Incubator</i></span>. Incubation is required |
| of all newly accepted projects until a further review indicates that the infrastructure, |
| communications, and decision making process have stabilized in a manner consistent with other |
| successful ASF projects. While incubation status is not necessarily a reflection of the completeness |
| or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF. |
| </p><p>"Copyright © 2017-2018, The Apache Software Foundation. Apache MXNet, MXNet, Apache, the Apache |
| feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the |
| Apache Software Foundation."</p> |
| </div> |
| </div> |
| </div> |
| </footer> |
| |
| |
| |
| |
| </body> |
| |
| </html> |