blob: 9aad0ae5d3df69c06fb50559bb59ebb512050891 [file] [log] [blame]
<!DOCTYPE html>
<!---
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
<html lang="en"><head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="/versions/master/assets/img/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 -->
<title>Module API | Apache MXNet</title>
<meta name="generator" content="Jekyll v4.0.0" />
<meta property="og:title" content="Module API" />
<meta property="og:locale" content="en_US" />
<meta name="description" content="A flexible and efficient library for deep learning." />
<meta property="og:description" content="A flexible and efficient library for deep learning." />
<link rel="canonical" href="https://mxnet.apache.org/versions/master/api/clojure/docs/tutorials/module" />
<meta property="og:url" content="https://mxnet.apache.org/versions/master/api/clojure/docs/tutorials/module" />
<meta property="og:site_name" content="Apache MXNet" />
<script type="application/ld+json">
{"url":"https://mxnet.apache.org/versions/master/api/clojure/docs/tutorials/module","headline":"Module API","description":"A flexible and efficient library for deep learning.","@type":"WebPage","@context":"https://schema.org"}</script>
<!-- End Jekyll SEO tag -->
<link rel="stylesheet" href="/versions/master/assets/docsearch.min.css" /><link rel="stylesheet" href="/versions/master/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/master/feed.xml" title="Apache MXNet" /><!-- Matomo -->
<script>
// Matomo analytics bootstrap: queue tracker commands on the shared `_paq`
// array, then load the tracker script asynchronously.
var _paq = window._paq = window._paq || [];
/* Tracker methods such as "setCustomDimension" must be queued before "trackPageView". */
/* Cookie tracking is explicitly disabled to avoid privacy issues. */
_paq.push(['disableCookies']);
_paq.push(['trackPageView']);
_paq.push(['enableLinkTracking']);
(function () {
  var baseUrl = "https://analytics.apache.org/";
  _paq.push(['setTrackerUrl', baseUrl + 'matomo.php']);
  _paq.push(['setSiteId', '23']);

  // Insert the tracker script ahead of the first <script> already in the page.
  var firstScript = document.getElementsByTagName('script')[0];
  var loader = document.createElement('script');
  loader.async = true;
  loader.src = baseUrl + 'matomo.js';
  firstScript.parentNode.insertBefore(loader, firstScript);
})();
</script>
<!-- End Matomo Code -->
<script src="/versions/master/assets/js/jquery-3.3.1.min.js"></script>
<script src="/versions/master/assets/js/docsearch.min.js"></script><script src="/versions/master/assets/js/globalSearch.js" defer></script>
<script src="/versions/master/assets/js/clipboard.js" defer></script>
<script src="/versions/master/assets/js/copycode.js" defer></script></head>
<body><header class="site-header" role="banner">
<script>
// Once the DOM is ready: wire up the header fade effect and mark the
// navigation links that match the current page.
$(document).ready(function () {
  // HEADER OPACITY LOGIC
  // Alpha starts at 0.4 and increases by 1/300 for every pixel scrolled.
  var refreshHeaderOpacity = function () {
    var alpha = $(window).scrollTop() / 300 + 0.4;
    $('.site-header').css("background-color", "rgba(4,140,204," + alpha + ")");
  };
  $(window).scroll(function () {
    refreshHeaderOpacity();
  });
  refreshHeaderOpacity();

  // MENU SELECTOR LOGIC
  // A link is flagged as current when its href occurs inside the page URL.
  var currentUrl = window.location.href;
  $('.page-link').each(function (_, link) {
    if (currentUrl.includes(link.href)) {
      $(link).addClass("page-current");
    }
  });
});
</script>
<div class="wrapper">
<a class="site-title" rel="author" href="/versions/master/"><img
src="/versions/master/assets/img/mxnet_logo.png" class="site-header-logo" alt="Apache MXNet"></a>
<nav class="site-nav">
<input type="checkbox" id="nav-trigger" class="nav-trigger"/>
<label for="nav-trigger">
<span class="menu-icon">
<svg viewBox="0 0 18 15" width="18px" height="15px">
<path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
</svg>
</span>
</label>
<div class="gs-search-border">
<div id="gs-search-icon"></div>
<form id="global-search-form">
<input id="global-search" type="text" title="Search" placeholder="Search" />
<div id="global-search-dropdown-container">
<button class="gs-current-version btn" type="button" data-toggle="dropdown">
<span id="gs-current-version-label">master</span>
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown">
<li class="gs-opt gs-versions active">master</li>
<li class="gs-opt gs-versions">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
<span id="global-search-close">x</span>
</form>
</div>
<div class="trigger">
<div id="global-search-mobile-border">
<div id="gs-search-icon-mobile"></div>
<input id="global-search-mobile" placeholder="Search..." type="text"/>
<div id="global-search-dropdown-container-mobile">
<button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown">
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown-mobile">
<li class="gs-opt gs-versions active">master</li>
<li class="gs-opt gs-versions">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
</div>
<a class="page-link" href="/versions/master/get_started">Get Started</a>
<a class="page-link" href="/versions/master/features">Features</a>
<a class="page-link" href="/versions/master/ecosystem">Ecosystem</a>
<a class="page-link" href="/versions/master/api">Docs &amp; Tutorials</a>
<a class="page-link" href="/versions/master/trusted_by">Trusted By</a>
<a class="page-link" href="https://github.com/apache/mxnet">GitHub</a>
<div class="dropdown" style="min-width:100px">
<span class="dropdown-header">Apache
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content" style="min-width:250px">
<a href="https://www.apache.org/foundation/">Apache Software Foundation</a>
<a href="https://www.apache.org/licenses/">License</a>
<a href="/versions/master/api/faq/security.html">Security</a>
<a href="https://privacy.apache.org/policies/privacy-policy-public.html">Privacy</a>
<a href="https://www.apache.org/events/current-event">Events</a>
<a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a>
<a href="https://www.apache.org/foundation/thanks.html">Thanks</a>
</div>
</div>
<div class="dropdown">
<span class="dropdown-header">master
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content">
<a class="dropdown-option-active" href="/">master</a>
<a href="/versions/1.9.1/">1.9.1</a>
<a href="/versions/1.8.0/">1.8.0</a>
<a href="/versions/1.7.0/">1.7.0</a>
<a href="/versions/1.6.0/">1.6.0</a>
<a href="/versions/1.5.0/">1.5.0</a>
<a href="/versions/1.4.1/">1.4.1</a>
<a href="/versions/1.3.1/">1.3.1</a>
<a href="/versions/1.2.1/">1.2.1</a>
<a href="/versions/1.1.0/">1.1.0</a>
<a href="/versions/1.0.0/">1.0.0</a>
<a href="/versions/0.12.1/">0.12.1</a>
<a href="/versions/0.11.0/">0.11.0</a>
</div>
</div>
</div>
</nav>
</div>
</header>
<main class="page-content" aria-label="Content">
<script>
</script>
<article class="post">
<header class="post-header wrapper">
<h1 class="post-title">Module API</h1>
<h3></h3></header>
<div class="post-content">
<div class="wrapper">
<div class="row">
<div class="col-3 docs-side-bar">
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page -->
</div>
<div class="col-9">
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
<h1 id="module-api">Module API</h1>
<p>The module API provides an intermediate and high-level interface for performing computation with neural networks in MXNet. Module wraps a Symbol and one or more Executors. It has both a high level and intermediate level API.</p>
<p>Topics:</p>
<ul>
<li><a href="#prepare-the-data">Prepare the Data</a></li>
<li><a href="#list-key-value-pairs">List Key-Value Pairs</a></li>
<li><a href="#preparing-a-module-for-computation">Preparing a Module for Computation</a></li>
<li><a href="#training-and-predicting">Training and Predicting</a></li>
<li><a href="#saving-and-loading">Saving and Loading</a></li>
<li><a href="/versions/master/api/clojure/docs/api">Clojure API Reference</a></li>
</ul>
<p>To follow along with this documentation, you can use this namespace with the needed requires:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="nf">ns</span><span class="w"> </span><span class="n">docs.module</span><span class="w">
</span><span class="p">(</span><span class="no">:require</span><span class="w"> </span><span class="p">[</span><span class="n">clojure.java.io</span><span class="w"> </span><span class="no">:as</span><span class="w"> </span><span class="n">io</span><span class="p">]</span><span class="w">
</span><span class="p">[</span><span class="n">clojure.java.shell</span><span class="w"> </span><span class="no">:refer</span><span class="w"> </span><span class="p">[</span><span class="n">sh</span><span class="p">]]</span><span class="w">
</span><span class="p">[</span><span class="n">org.apache.clojure-mxnet.eval-metric</span><span class="w"> </span><span class="no">:as</span><span class="w"> </span><span class="n">eval-metric</span><span class="p">]</span><span class="w">
</span><span class="p">[</span><span class="n">org.apache.clojure-mxnet.io</span><span class="w"> </span><span class="no">:as</span><span class="w"> </span><span class="n">mx-io</span><span class="p">]</span><span class="w">
</span><span class="p">[</span><span class="n">org.apache.clojure-mxnet.module</span><span class="w"> </span><span class="no">:as</span><span class="w"> </span><span class="n">m</span><span class="p">]</span><span class="w">
</span><span class="p">[</span><span class="n">org.apache.clojure-mxnet.symbol</span><span class="w"> </span><span class="no">:as</span><span class="w"> </span><span class="n">sym</span><span class="p">]</span><span class="w">
</span><span class="p">[</span><span class="n">org.apache.clojure-mxnet.ndarray</span><span class="w"> </span><span class="no">:as</span><span class="w"> </span><span class="n">ndarray</span><span class="p">]))</span><span class="w">
</span></code></pre></div></div>
<h2 id="prepare-the-data">Prepare the Data</h2>
<p>In this example, we are going to use the MNIST data set. If you have cloned the MXNet repo and run <code class="highlighter-rouge">cd contrib/clojure-package</code>, we can run a helper script to download the data for us.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">data-dir</span><span class="w"> </span><span class="s">"data/"</span><span class="p">)</span><span class="w">
</span><span class="p">(</span><span class="nb">when-not</span><span class="w"> </span><span class="p">(</span><span class="nf">.exists</span><span class="w"> </span><span class="p">(</span><span class="nf">io/file</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="n">data-dir</span><span class="w"> </span><span class="s">"train-images-idx3-ubyte"</span><span class="p">)))</span><span class="w">
</span><span class="p">(</span><span class="nf">sh</span><span class="w"> </span><span class="s">"../../scripts/get_mnist_data.sh"</span><span class="p">))</span><span class="w">
</span></code></pre></div></div>
<p>MXNet provides functions in the <code class="highlighter-rouge">io</code> namespace to load the MNIST datasets into training and test data iterators that we can use with our module.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">train-data</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/mnist-iter</span><span class="w"> </span><span class="p">{</span><span class="no">:image</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="n">data-dir</span><span class="w"> </span><span class="s">"train-images-idx3-ubyte"</span><span class="p">)</span><span class="w">
</span><span class="no">:label</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="n">data-dir</span><span class="w"> </span><span class="s">"train-labels-idx1-ubyte"</span><span class="p">)</span><span class="w">
</span><span class="no">:label-name</span><span class="w"> </span><span class="s">"softmax_label"</span><span class="w">
</span><span class="no">:input-shape</span><span class="w"> </span><span class="p">[</span><span class="mi">784</span><span class="p">]</span><span class="w">
</span><span class="no">:batch-size</span><span class="w"> </span><span class="mi">10</span><span class="w">
</span><span class="no">:shuffle</span><span class="w"> </span><span class="n">true</span><span class="w">
</span><span class="no">:flat</span><span class="w"> </span><span class="n">true</span><span class="w">
</span><span class="no">:silent</span><span class="w"> </span><span class="n">false</span><span class="w">
</span><span class="no">:seed</span><span class="w"> </span><span class="mi">10</span><span class="p">}))</span><span class="w">
</span><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">test-data</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/mnist-iter</span><span class="w"> </span><span class="p">{</span><span class="no">:image</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="n">data-dir</span><span class="w"> </span><span class="s">"t10k-images-idx3-ubyte"</span><span class="p">)</span><span class="w">
</span><span class="no">:label</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="n">data-dir</span><span class="w"> </span><span class="s">"t10k-labels-idx1-ubyte"</span><span class="p">)</span><span class="w">
</span><span class="no">:input-shape</span><span class="w"> </span><span class="p">[</span><span class="mi">784</span><span class="p">]</span><span class="w">
</span><span class="no">:batch-size</span><span class="w"> </span><span class="mi">10</span><span class="w">
</span><span class="no">:flat</span><span class="w"> </span><span class="n">true</span><span class="w">
</span><span class="no">:silent</span><span class="w"> </span><span class="n">false</span><span class="p">}))</span><span class="w">
</span></code></pre></div></div>
<h2 id="preparing-a-module-for-computation">Preparing a Module for Computation</h2>
<p>To construct a module, we need to have a symbol as input. This symbol takes input data in the first layer and then has subsequent layers of fully connected and relu activation layers, ending up in a softmax layer for output.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">let</span><span class="w"> </span><span class="p">[</span><span class="n">data</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/variable</span><span class="w"> </span><span class="s">"data"</span><span class="p">)</span><span class="w">
</span><span class="n">fc1</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/fully-connected</span><span class="w"> </span><span class="s">"fc1"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="no">:num-hidden</span><span class="w"> </span><span class="mi">128</span><span class="p">})</span><span class="w">
</span><span class="n">act1</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/activation</span><span class="w"> </span><span class="s">"relu1"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">fc1</span><span class="w"> </span><span class="no">:act-type</span><span class="w"> </span><span class="s">"relu"</span><span class="p">})</span><span class="w">
</span><span class="n">fc2</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/fully-connected</span><span class="w"> </span><span class="s">"fc2"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">act1</span><span class="w"> </span><span class="no">:num-hidden</span><span class="w"> </span><span class="mi">64</span><span class="p">})</span><span class="w">
</span><span class="n">act2</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/activation</span><span class="w"> </span><span class="s">"relu2"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">fc2</span><span class="w"> </span><span class="no">:act-type</span><span class="w"> </span><span class="s">"relu"</span><span class="p">})</span><span class="w">
</span><span class="n">fc3</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/fully-connected</span><span class="w"> </span><span class="s">"fc3"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">act2</span><span class="w"> </span><span class="no">:num-hidden</span><span class="w"> </span><span class="mi">10</span><span class="p">})</span><span class="w">
</span><span class="n">out</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/softmax-output</span><span class="w"> </span><span class="s">"softmax"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">fc3</span><span class="p">})]</span><span class="w">
</span><span class="n">out</span><span class="p">)</span><span class="w">
</span><span class="c1">;=&gt;#object[org.apache.mxnet.Symbol 0x1f43a406 "org.apache.mxnet.Symbol@1f43a406"]</span><span class="w">
</span></code></pre></div></div>
<p>You can also write this with the <code class="highlighter-rouge">as-&gt;</code> threading macro.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="p">(</span><span class="nf">as-&gt;</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/variable</span><span class="w"> </span><span class="s">"data"</span><span class="p">)</span><span class="w"> </span><span class="n">data</span><span class="w">
</span><span class="p">(</span><span class="nf">sym/fully-connected</span><span class="w"> </span><span class="s">"fc1"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="no">:num-hidden</span><span class="w"> </span><span class="mi">128</span><span class="p">})</span><span class="w">
</span><span class="p">(</span><span class="nf">sym/activation</span><span class="w"> </span><span class="s">"relu1"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="no">:act-type</span><span class="w"> </span><span class="s">"relu"</span><span class="p">})</span><span class="w">
</span><span class="p">(</span><span class="nf">sym/fully-connected</span><span class="w"> </span><span class="s">"fc2"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="no">:num-hidden</span><span class="w"> </span><span class="mi">64</span><span class="p">})</span><span class="w">
</span><span class="p">(</span><span class="nf">sym/activation</span><span class="w"> </span><span class="s">"relu2"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="no">:act-type</span><span class="w"> </span><span class="s">"relu"</span><span class="p">})</span><span class="w">
</span><span class="p">(</span><span class="nf">sym/fully-connected</span><span class="w"> </span><span class="s">"fc3"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="no">:num-hidden</span><span class="w"> </span><span class="mi">10</span><span class="p">})</span><span class="w">
</span><span class="p">(</span><span class="nf">sym/softmax-output</span><span class="w"> </span><span class="s">"softmax"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="p">})))</span><span class="w">
</span><span class="c1">;=&gt; #'docs.module/out</span><span class="w">
</span></code></pre></div></div>
<p>By default, <code class="highlighter-rouge">context</code> is the CPU. If you need data parallelization, you can specify a GPU context or an array of GPU contexts like this: <code class="highlighter-rouge">(m/module out {:contexts [(context/gpu)]})</code>.</p>
<p>Before you can compute with a module, you need to call <code class="highlighter-rouge">bind</code> to allocate the device memory and <code class="highlighter-rouge">init-params</code> or <code class="highlighter-rouge">set-params</code> to initialize the parameters. If you simply want to fit a module, you don’t need to call <code class="highlighter-rouge">bind</code> and <code class="highlighter-rouge">init-params</code> explicitly, because the <code class="highlighter-rouge">fit</code> function automatically calls them if they are needed.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">let</span><span class="w"> </span><span class="p">[</span><span class="n">mod</span><span class="w"> </span><span class="p">(</span><span class="nf">m/module</span><span class="w"> </span><span class="n">out</span><span class="p">)]</span><span class="w">
</span><span class="p">(</span><span class="nb">-&gt;</span><span class="w"> </span><span class="n">mod</span><span class="w">
</span><span class="p">(</span><span class="nf">m/bind</span><span class="w"> </span><span class="p">{</span><span class="no">:data-shapes</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/provide-data</span><span class="w"> </span><span class="n">train-data</span><span class="p">)</span><span class="w">
</span><span class="no">:label-shapes</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/provide-label</span><span class="w"> </span><span class="n">train-data</span><span class="p">)})</span><span class="w">
</span><span class="p">(</span><span class="nf">m/init-params</span><span class="p">)))</span><span class="w">
</span></code></pre></div></div>
<p>Now you can compute with the module using functions like <code class="highlighter-rouge">forward</code>, <code class="highlighter-rouge">backward</code>, etc.</p>
<h2 id="training-and-predicting">Training and Predicting</h2>
<p>Modules provide high-level APIs for training, predicting, and evaluating. To fit a module, call the <code class="highlighter-rouge">fit</code> function with some data iterators:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">mod</span><span class="w"> </span><span class="p">(</span><span class="nf">m/fit</span><span class="w"> </span><span class="p">(</span><span class="nf">m/module</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="no">:train-data</span><span class="w"> </span><span class="n">train-data</span><span class="w"> </span><span class="no">:eval-data</span><span class="w"> </span><span class="n">test-data</span><span class="w"> </span><span class="no">:num-epoch</span><span class="w"> </span><span class="mi">1</span><span class="p">}))</span><span class="w">
</span><span class="c1">;; Epoch 0 Train- [accuracy 0.12521666]</span><span class="w">
</span><span class="c1">;; Epoch 0 Time cost- 8392</span><span class="w">
</span><span class="c1">;; Epoch 0 Validation- [accuracy 0.2227]</span><span class="w">
</span></code></pre></div></div>
<p>You can pass in batch-end callbacks using batch-end-callback and epoch-end callbacks using epoch-end-callback in the <code class="highlighter-rouge">fit-params</code>. You can also set parameters such as optimizer and eval-metric in the fit-params. To learn more about the fit-params, see the fit-params function options. To predict with a module, call <code class="highlighter-rouge">predict</code> with a DataIter:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">results</span><span class="w"> </span><span class="p">(</span><span class="nf">m/predict</span><span class="w"> </span><span class="n">mod</span><span class="w"> </span><span class="p">{</span><span class="no">:eval-data</span><span class="w"> </span><span class="n">test-data</span><span class="p">}))</span><span class="w">
</span><span class="p">(</span><span class="nb">first</span><span class="w"> </span><span class="n">results</span><span class="p">)</span><span class="w"> </span><span class="c1">;=&gt;#object[org.apache.mxnet.NDArray 0x3540b6d3 "org.apache.mxnet.NDArray@a48686ec"]</span><span class="w">
</span><span class="p">(</span><span class="nb">first</span><span class="w"> </span><span class="p">(</span><span class="nf">ndarray/-&gt;vec</span><span class="w"> </span><span class="p">(</span><span class="nb">first</span><span class="w"> </span><span class="n">results</span><span class="p">)))</span><span class="w"> </span><span class="c1">;=&gt;0.08261358</span><span class="w">
</span></code></pre></div></div>
<p>The module collects and returns all of the prediction results. For more details about the format of the return values, see the documentation for the <a href="/versions/master/api/clojure/docs/api/org.apache.clojure-mxnet.module.html#var-predict"><code class="highlighter-rouge">predict</code></a> function.</p>
<p>When prediction results might be too large to fit in memory, use the <a href="/versions/master/api/clojure/docs/api/org.apache.clojure-mxnet.module.html#var-predict-every-batch"><code class="highlighter-rouge">predict-every-batch</code></a> API.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">let</span><span class="w"> </span><span class="p">[</span><span class="n">preds</span><span class="w"> </span><span class="p">(</span><span class="nf">m/predict-every-batch</span><span class="w"> </span><span class="n">mod</span><span class="w"> </span><span class="p">{</span><span class="no">:eval-data</span><span class="w"> </span><span class="n">test-data</span><span class="p">})]</span><span class="w">
</span><span class="p">(</span><span class="nf">mx-io/reduce-batches</span><span class="w"> </span><span class="n">test-data</span><span class="w">
</span><span class="p">(</span><span class="k">fn</span><span class="w"> </span><span class="p">[</span><span class="n">i</span><span class="w"> </span><span class="n">batch</span><span class="p">]</span><span class="w">
</span><span class="p">(</span><span class="nb">println</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="s">"pred is "</span><span class="w"> </span><span class="p">(</span><span class="nb">first</span><span class="w"> </span><span class="p">(</span><span class="nb">get</span><span class="w"> </span><span class="n">preds</span><span class="w"> </span><span class="n">i</span><span class="p">))))</span><span class="w">
</span><span class="p">(</span><span class="nb">println</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="s">"label is "</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/batch-label</span><span class="w"> </span><span class="n">batch</span><span class="p">)))</span><span class="w">
</span><span class="c1">;;; do something</span><span class="w">
</span><span class="p">(</span><span class="nb">inc</span><span class="w"> </span><span class="n">i</span><span class="p">))))</span><span class="w">
</span></code></pre></div></div>
<p>If you need to evaluate on a test set and don’t need the prediction output, call the <code class="highlighter-rouge">score</code> function with a data iterator and an eval metric:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="nf">m/score</span><span class="w"> </span><span class="n">mod</span><span class="w"> </span><span class="p">{</span><span class="no">:eval-data</span><span class="w"> </span><span class="n">test-data</span><span class="w"> </span><span class="no">:eval-metric</span><span class="w"> </span><span class="p">(</span><span class="nf">eval-metric/accuracy</span><span class="p">)})</span><span class="w"> </span><span class="c1">;=&gt;["accuracy" 0.2227]</span><span class="w">
</span></code></pre></div></div>
<p>This runs predictions on each batch in the provided data iterator and computes the evaluation score using the provided eval metric. The evaluation results are stored in the <code class="highlighter-rouge">eval-metric</code> object itself so that you can query them later.</p>
<h2 id="saving-and-loading">Saving and Loading</h2>
<p>To save the module parameters in each training epoch, use the <code class="highlighter-rouge">save-checkpoint</code> function:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">let</span><span class="w"> </span><span class="p">[</span><span class="n">save-prefix</span><span class="w"> </span><span class="s">"my-model"</span><span class="p">]</span><span class="w">
</span><span class="p">(</span><span class="nb">doseq</span><span class="w"> </span><span class="p">[</span><span class="n">epoch-num</span><span class="w"> </span><span class="p">(</span><span class="nb">range</span><span class="w"> </span><span class="mi">3</span><span class="p">)]</span><span class="w">
</span><span class="p">(</span><span class="nf">mx-io/do-batches</span><span class="w"> </span><span class="n">train-data</span><span class="w"> </span><span class="p">(</span><span class="k">fn</span><span class="w"> </span><span class="p">[</span><span class="n">batch</span><span class="w">
</span><span class="c1">;; do something</span><span class="w">
</span><span class="p">]))</span><span class="w">
</span><span class="p">(</span><span class="nf">m/save-checkpoint</span><span class="w"> </span><span class="n">mod</span><span class="w"> </span><span class="p">{</span><span class="no">:prefix</span><span class="w"> </span><span class="n">save-prefix</span><span class="w"> </span><span class="no">:epoch</span><span class="w"> </span><span class="n">epoch-num</span><span class="w"> </span><span class="no">:save-opt-states</span><span class="w"> </span><span class="n">true</span><span class="p">})))</span><span class="w">
</span><span class="c1">;; INFO org.apache.mxnet.module.Module: Saved checkpoint to my-model-0000.params</span><span class="w">
</span><span class="c1">;; INFO org.apache.mxnet.module.Module: Saved optimizer state to my-model-0000.states</span><span class="w">
</span><span class="c1">;; INFO org.apache.mxnet.module.Module: Saved checkpoint to my-model-0001.params</span><span class="w">
</span><span class="c1">;; INFO org.apache.mxnet.module.Module: Saved optimizer state to my-model-0001.states</span><span class="w">
</span><span class="c1">;; INFO org.apache.mxnet.module.Module: Saved checkpoint to my-model-0002.params</span><span class="w">
</span><span class="c1">;; INFO org.apache.mxnet.module.Module: Saved optimizer state to my-model-0002.states</span><span class="w">
</span></code></pre></div></div>
<p>To load the saved module parameters, call the <code class="highlighter-rouge">load-checkpoint</code> function:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">new-mod</span><span class="w"> </span><span class="p">(</span><span class="nf">m/load-checkpoint</span><span class="w"> </span><span class="p">{</span><span class="no">:prefix</span><span class="w"> </span><span class="s">"my-model"</span><span class="w"> </span><span class="no">:epoch</span><span class="w"> </span><span class="mi">1</span><span class="w"> </span><span class="no">:load-optimizer-states</span><span class="w"> </span><span class="n">true</span><span class="p">}))</span><span class="w">
</span><span class="n">new-mod</span><span class="w"> </span><span class="c1">;=&gt; #object[org.apache.mxnet.module.Module 0x5304d0f4 "org.apache.mxnet.module.Module@5304d0f4"]</span><span class="w">
</span></code></pre></div></div>
<p>To initialize parameters, first bind the symbols to construct executors with the <code class="highlighter-rouge">bind</code> function. Then, initialize the parameters and auxiliary states by calling the <code class="highlighter-rouge">init-params</code> function.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="nb">-&gt;</span><span class="w"> </span><span class="n">new-mod</span><span class="w">
</span><span class="p">(</span><span class="nf">m/bind</span><span class="w"> </span><span class="p">{</span><span class="no">:data-shapes</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/provide-data</span><span class="w"> </span><span class="n">train-data</span><span class="p">)</span><span class="w"> </span><span class="no">:label-shapes</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/provide-label</span><span class="w"> </span><span class="n">train-data</span><span class="p">)})</span><span class="w">
</span><span class="p">(</span><span class="nf">m/init-params</span><span class="p">))</span><span class="w">
</span></code></pre></div></div>
<p>To get the current parameters, use the <code class="highlighter-rouge">params</code> function:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="w">
</span><span class="p">(</span><span class="k">let</span><span class="w"> </span><span class="p">[[</span><span class="n">arg-params</span><span class="w"> </span><span class="n">aux-params</span><span class="p">]</span><span class="w"> </span><span class="p">(</span><span class="nf">m/params</span><span class="w"> </span><span class="n">new-mod</span><span class="p">)]</span><span class="w">
</span><span class="p">{</span><span class="no">:arg-params</span><span class="w"> </span><span class="n">arg-params</span><span class="w">
</span><span class="no">:aux-params</span><span class="w"> </span><span class="n">aux-params</span><span class="p">})</span><span class="w">
</span><span class="c1">;; {:arg-params</span><span class="w">
</span><span class="c1">;; {"fc3_bias"</span><span class="w">
</span><span class="c1">;; #object[org.apache.mxnet.NDArray 0x39adc3b0 "org.apache.mxnet.NDArray@49caf426"],</span><span class="w">
</span><span class="c1">;; "fc2_weight"</span><span class="w">
</span><span class="c1">;; #object[org.apache.mxnet.NDArray 0x25baf623 "org.apache.mxnet.NDArray@a6c8f9ac"],</span><span class="w">
</span><span class="c1">;; "fc1_bias"</span><span class="w">
</span><span class="c1">;; #object[org.apache.mxnet.NDArray 0x6e089973 "org.apache.mxnet.NDArray@9f91d6eb"],</span><span class="w">
</span><span class="c1">;; "fc3_weight"</span><span class="w">
</span><span class="c1">;; #object[org.apache.mxnet.NDArray 0x756fd109 "org.apache.mxnet.NDArray@2dd0fe3c"],</span><span class="w">
</span><span class="c1">;; "fc2_bias"</span><span class="w">
</span><span class="c1">;; #object[org.apache.mxnet.NDArray 0x1dc69c8b "org.apache.mxnet.NDArray@d128f73d"],</span><span class="w">
</span><span class="c1">;; "fc1_weight"</span><span class="w">
</span><span class="c1">;; #object[org.apache.mxnet.NDArray 0x20abc769 "org.apache.mxnet.NDArray@b8e1c5e8"]},</span><span class="w">
</span><span class="c1">;; :aux-params {}}</span><span class="w">
</span></code></pre></div></div>
<p>To assign parameter and aux state values, use the <code class="highlighter-rouge">set-params</code> function.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="nf">m/set-params</span><span class="w"> </span><span class="n">new-mod</span><span class="w"> </span><span class="p">{</span><span class="no">:arg-params</span><span class="w"> </span><span class="p">(</span><span class="nf">m/arg-params</span><span class="w"> </span><span class="n">new-mod</span><span class="p">)</span><span class="w"> </span><span class="no">:aux-params</span><span class="w"> </span><span class="p">(</span><span class="nf">m/aux-params</span><span class="w"> </span><span class="n">new-mod</span><span class="p">)})</span><span class="w">
</span><span class="c1">;=&gt; #object[org.apache.mxnet.module.Module 0x5304d0f4 "org.apache.mxnet.module.Module@5304d0f4"]</span><span class="w">
</span></code></pre></div></div>
<p>To resume training from a saved checkpoint, pass the loaded parameters to the <code class="highlighter-rouge">fit</code> function. This will prevent <code class="highlighter-rouge">fit</code> from initializing the parameters randomly.</p>
<p>Create fit-params and then use it to set <code class="highlighter-rouge">begin-epoch</code> so that <code class="highlighter-rouge">fit</code> knows to resume from a saved epoch.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">;; reset the training data before calling fit or you will get an error</span><span class="w">
</span><span class="p">(</span><span class="nf">mx-io/reset</span><span class="w"> </span><span class="n">train-data</span><span class="p">)</span><span class="w">
</span><span class="p">(</span><span class="nf">mx-io/reset</span><span class="w"> </span><span class="n">test-data</span><span class="p">)</span><span class="w">
</span><span class="p">(</span><span class="nf">m/fit</span><span class="w"> </span><span class="n">new-mod</span><span class="w"> </span><span class="p">{</span><span class="no">:train-data</span><span class="w"> </span><span class="n">train-data</span><span class="w"> </span><span class="no">:eval-data</span><span class="w"> </span><span class="n">test-data</span><span class="w"> </span><span class="no">:num-epoch</span><span class="w"> </span><span class="mi">2</span><span class="w">
</span><span class="no">:fit-params</span><span class="w"> </span><span class="p">(</span><span class="nb">-&gt;</span><span class="w"> </span><span class="p">(</span><span class="nf">m/fit-params</span><span class="w"> </span><span class="p">{</span><span class="no">:begin-epoch</span><span class="w"> </span><span class="mi">1</span><span class="p">}))})</span><span class="w">
</span></code></pre></div></div>
<h2 id="next-steps">Next Steps</h2>
<ul>
<li>See <a href="symbol">Symbolic API</a> for operations on NDArrays that assemble neural networks from layers.</li>
<li>See <a href="ndarray">NDArray API</a> for vector/matrix/tensor operations.</li>
<li>See <a href="kvstore">KVStore API</a> for multi-GPU and multi-host distributed training.</li>
</ul>
</div>
</div>
</div>
</div>
</article>
</main><footer class="site-footer h-card">
<div class="wrapper">
<div class="row">
<div class="col-4">
<h4 class="footer-category-title">Resources</h4>
<ul class="contact-list">
<li><a href="/versions/master/community#stay-connected">Mailing lists</a></li>
<li><a href="/versions/master/community#github-issues">GitHub Issues</a></li>
<li><a href="https://github.com/apache/mxnet/projects">Projects</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a href="https://discuss.mxnet.io">Forum</a></li>
<li><a href="/versions/master/community">Contribute To MXNet</a></li>
</ul>
</div>
<div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/mxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul>
</div>
<div class="col-4 footer-text">
<p>A flexible and efficient library for deep learning.</p>
</div>
</div>
</div>
</footer>
<footer class="site-footer2">
<div class="wrapper">
<div class="row">
<div class="col-3">
<img src="/versions/master/assets/img/asf_logo.svg" class="footer-logo col-2" alt="Apache Software Foundation logo">
</div>
<div class="footer-bottom-warning col-9">
<p>"Copyright © 2017-2022, The Apache Software Foundation. Licensed under the Apache License, Version 2.0. Apache MXNet, MXNet, Apache, the Apache
feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the
Apache Software Foundation."</p>
</div>
</div>
</div>
</footer>
</body>
</html>