blob: d69ad443df30a3fd66b8e535bacbd7b07f5d38bb [file] [log] [blame]
<!DOCTYPE html>
<html lang="en"><head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 -->
<title>Module API | Apache MXNet</title>
<meta name="generator" content="Jekyll v4.0.0" />
<meta property="og:title" content="Module API" />
<meta property="og:locale" content="en_US" />
<meta name="description" content="A flexible and efficient library for deep learning." />
<meta property="og:description" content="A flexible and efficient library for deep learning." />
<link rel="canonical" href="https://mxnet.apache.org/versions/master/api/clojure/docs/tutorials/module" />
<meta property="og:url" content="https://mxnet.apache.org/versions/master/api/clojure/docs/tutorials/module" />
<meta property="og:site_name" content="Apache MXNet" />
<script type="application/ld+json">
{"url":"https://mxnet.apache.org/versions/master/api/clojure/docs/tutorials/module","headline":"Module API","description":"A flexible and efficient library for deep learning.","@type":"WebPage","@context":"https://schema.org"}</script>
<!-- End Jekyll SEO tag -->
<script src="https://medium-widget.pixelpoint.io/widget.js"></script>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.css" />
<link rel="stylesheet" href="/versions/master/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/master/feed.xml" title="Apache MXNet" /><script>
// Load Google Analytics only if the visitor has not enabled Do Not Track
// (checks window.doNotTrack, navigator.doNotTrack, and the legacy navigator.msDoNotTrack flag).
if(!(window.doNotTrack === "1" || navigator.doNotTrack === "1" || navigator.doNotTrack === "yes" || navigator.msDoNotTrack === "1")) {
// Standard analytics.js loader: defines a global `ga` stub (unless one already exists)
// that pushes its arguments onto ga.q, records a timestamp in ga.l, and then inserts
// an async script element for analytics.js in front of the first script element on the page.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
// Create the tracker for this property and record a page view for the current page.
// NOTE(review): UA-* IDs are Universal Analytics properties, which Google has retired;
// confirm this tracker still collects data, or migrate/remove it.
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
}
</script>
<script src="/versions/master/assets/js/jquery-3.3.1.min.js"></script><script src="https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js" defer></script>
<script src="/versions/master/assets/js/globalSearch.js" defer></script>
<script src="/versions/master/assets/js/clipboard.js" defer></script>
<script src="/versions/master/assets/js/copycode.js" defer></script></head>
<body><header class="site-header" role="banner">
<script>
$(function () {
  // Header tint: MXNet blue whose opacity starts at 0.4 at the top of the page
  // and grows by 1/300 for every pixel scrolled (CSS clamps alpha values above 1).
  var refreshHeaderTint = function () {
    var alpha = $(window).scrollTop() / 300 + 0.4;
    $('.site-header').css("background-color", "rgba(4,140,204," + alpha + ")");
  };
  $(window).on('scroll', refreshHeaderTint);
  refreshHeaderTint();

  // Nav highlight: mark every top-nav link whose absolute URL appears
  // inside the current page URL as the current section.
  $('.page-link')
    .filter(function () {
      return window.location.href.includes(this.href);
    })
    .addClass("page-current");
});
</script>
<div class="wrapper">
<a class="site-title" rel="author" href="/versions/master/"><img
src="/versions/master/assets/img/mxnet_logo.png" class="site-header-logo" alt="Apache MXNet home"></a>
<nav class="site-nav">
<input type="checkbox" id="nav-trigger" class="nav-trigger"/>
<label for="nav-trigger">
<span class="menu-icon">
<svg viewBox="0 0 18 15" width="18px" height="15px">
<path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
</svg>
</span>
</label>
<div class="gs-search-border">
<div id="gs-search-icon"></div>
<form id="global-search-form">
<input id="global-search" type="text" title="Search" placeholder="Search" />
<div id="global-search-dropdown-container">
<button class="gs-current-version btn" type="button" data-toggle="dropdown">
<span id="gs-current-version-label">master</span>
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown">
<li class="gs-opt gs-versions active">master</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
<span id="global-search-close">x</span>
</form>
</div>
<div class="trigger">
<div id="global-search-mobile-border">
<div id="gs-search-icon-mobile"></div>
<input id="global-search-mobile" placeholder="Search..." type="text" aria-label="Search"/>
<div id="global-search-dropdown-container-mobile">
<button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown">
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown-mobile">
<li class="gs-opt gs-versions active">master</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
</div>
<a class="page-link" href="/versions/master/get_started">Get Started</a>
<a class="page-link" href="/versions/master/blog">Blog</a>
<a class="page-link" href="/versions/master/features">Features</a>
<a class="page-link" href="/versions/master/ecosystem">Ecosystem</a>
<a class="page-link" href="/versions/master/api">Docs & Tutorials</a>
<a class="page-link" href="https://github.com/apache/incubator-mxnet">GitHub</a>
<div class="dropdown">
<span class="dropdown-header">master
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content">
<a class="dropdown-option-active" href="/">master</a>
<a href="/versions/1.7.0/">1.7.0</a>
<a href="/versions/1.6.0/">1.6.0</a>
<a href="/versions/1.5.0/">1.5.0</a>
<a href="/versions/1.4.1/">1.4.1</a>
<a href="/versions/1.3.1/">1.3.1</a>
<a href="/versions/1.2.1/">1.2.1</a>
<a href="/versions/1.1.0/">1.1.0</a>
<a href="/versions/1.0.0/">1.0.0</a>
<a href="/versions/0.12.1/">0.12.1</a>
<a href="/versions/0.11.0/">0.11.0</a>
</div>
</div>
</div>
</nav>
</div>
</header>
<main class="page-content" aria-label="Content">
<script>
</script>
<article class="post">
<header class="post-header wrapper">
<h1 class="post-title">Module API</h1>
<h3></h3></header>
<div class="post-content">
<div class="wrapper">
<div class="row">
<div class="col-3 docs-side-bar">
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page -->
</div>
<div class="col-9">
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
<h1 id="module-api">Module API</h1>
<p>The module API provides an intermediate and high-level interface for performing computation with neural networks in MXNet. Module wraps a Symbol and one or more Executors. It has both a high level and intermediate level API.</p>
<p>Topics:</p>
<ul>
<li><a href="#prepare-the-data">Prepare the Data</a></li>
<li><a href="#preparing-a-module-for-computation">Preparing a Module for Computation</a></li>
<li><a href="#training-and-predicting">Training and Predicting</a></li>
<li><a href="#saving-and-loading">Saving and Loading</a></li>
<li><a href="/versions/master/api/clojure/docs/api">Clojure API Reference</a></li>
</ul>
<p>To follow along with this documentation, you can use this namespace declaration with the needed requires:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="nf">ns</span><span class="w"> </span><span class="n">docs.module</span><span class="w">
</span><span class="p">(</span><span class="no">:require</span><span class="w"> </span><span class="p">[</span><span class="n">clojure.java.io</span><span class="w"> </span><span class="no">:as</span><span class="w"> </span><span class="n">io</span><span class="p">]</span><span class="w">
</span><span class="p">[</span><span class="n">clojure.java.shell</span><span class="w"> </span><span class="no">:refer</span><span class="w"> </span><span class="p">[</span><span class="n">sh</span><span class="p">]]</span><span class="w">
</span><span class="p">[</span><span class="n">org.apache.clojure-mxnet.eval-metric</span><span class="w"> </span><span class="no">:as</span><span class="w"> </span><span class="n">eval-metric</span><span class="p">]</span><span class="w">
</span><span class="p">[</span><span class="n">org.apache.clojure-mxnet.io</span><span class="w"> </span><span class="no">:as</span><span class="w"> </span><span class="n">mx-io</span><span class="p">]</span><span class="w">
</span><span class="p">[</span><span class="n">org.apache.clojure-mxnet.module</span><span class="w"> </span><span class="no">:as</span><span class="w"> </span><span class="n">m</span><span class="p">]</span><span class="w">
</span><span class="p">[</span><span class="n">org.apache.clojure-mxnet.symbol</span><span class="w"> </span><span class="no">:as</span><span class="w"> </span><span class="n">sym</span><span class="p">]</span><span class="w">
</span><span class="p">[</span><span class="n">org.apache.clojure-mxnet.ndarray</span><span class="w"> </span><span class="no">:as</span><span class="w"> </span><span class="n">ndarray</span><span class="p">]))</span><span class="w">
</span></code></pre></div></div>
<h2 id="prepare-the-data">Prepare the Data</h2>
<p>In this example, we use the MNIST data set. If you have cloned the MXNet repo and run <code class="highlighter-rouge">cd contrib/clojure-package</code>, you can run a helper script to download the data for you.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">data-dir</span><span class="w"> </span><span class="s">"data/"</span><span class="p">)</span><span class="w">
</span><span class="p">(</span><span class="nb">when-not</span><span class="w"> </span><span class="p">(</span><span class="nf">.exists</span><span class="w"> </span><span class="p">(</span><span class="nf">io/file</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="n">data-dir</span><span class="w"> </span><span class="s">"train-images-idx3-ubyte"</span><span class="p">)))</span><span class="w">
</span><span class="p">(</span><span class="nf">sh</span><span class="w"> </span><span class="s">"../../scripts/get_mnist_data.sh"</span><span class="p">))</span><span class="w">
</span></code></pre></div></div>
<p>MXNet provides the <code class="highlighter-rouge">mnist-iter</code> function in the <code class="highlighter-rouge">io</code> namespace to load the MNIST datasets into training and test data iterators that we can use with our module.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">train-data</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/mnist-iter</span><span class="w"> </span><span class="p">{</span><span class="no">:image</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="n">data-dir</span><span class="w"> </span><span class="s">"train-images-idx3-ubyte"</span><span class="p">)</span><span class="w">
</span><span class="no">:label</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="n">data-dir</span><span class="w"> </span><span class="s">"train-labels-idx1-ubyte"</span><span class="p">)</span><span class="w">
</span><span class="no">:label-name</span><span class="w"> </span><span class="s">"softmax_label"</span><span class="w">
</span><span class="no">:input-shape</span><span class="w"> </span><span class="p">[</span><span class="mi">784</span><span class="p">]</span><span class="w">
</span><span class="no">:batch-size</span><span class="w"> </span><span class="mi">10</span><span class="w">
</span><span class="no">:shuffle</span><span class="w"> </span><span class="n">true</span><span class="w">
</span><span class="no">:flat</span><span class="w"> </span><span class="n">true</span><span class="w">
</span><span class="no">:silent</span><span class="w"> </span><span class="n">false</span><span class="w">
</span><span class="no">:seed</span><span class="w"> </span><span class="mi">10</span><span class="p">}))</span><span class="w">
</span><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">test-data</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/mnist-iter</span><span class="w"> </span><span class="p">{</span><span class="no">:image</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="n">data-dir</span><span class="w"> </span><span class="s">"t10k-images-idx3-ubyte"</span><span class="p">)</span><span class="w">
</span><span class="no">:label</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="n">data-dir</span><span class="w"> </span><span class="s">"t10k-labels-idx1-ubyte"</span><span class="p">)</span><span class="w">
</span><span class="no">:input-shape</span><span class="w"> </span><span class="p">[</span><span class="mi">784</span><span class="p">]</span><span class="w">
</span><span class="no">:batch-size</span><span class="w"> </span><span class="mi">10</span><span class="w">
</span><span class="no">:flat</span><span class="w"> </span><span class="n">true</span><span class="w">
</span><span class="no">:silent</span><span class="w"> </span><span class="n">false</span><span class="p">}))</span><span class="w">
</span></code></pre></div></div>
<h2 id="preparing-a-module-for-computation">Preparing a Module for Computation</h2>
<p>To construct a module, you need a symbol as input. This symbol takes input data in the first layer, followed by fully connected layers with ReLU activations, and ends with a softmax output layer.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">let</span><span class="w"> </span><span class="p">[</span><span class="n">data</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/variable</span><span class="w"> </span><span class="s">"data"</span><span class="p">)</span><span class="w">
</span><span class="n">fc1</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/fully-connected</span><span class="w"> </span><span class="s">"fc1"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="no">:num-hidden</span><span class="w"> </span><span class="mi">128</span><span class="p">})</span><span class="w">
</span><span class="n">act1</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/activation</span><span class="w"> </span><span class="s">"relu1"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">fc1</span><span class="w"> </span><span class="no">:act-type</span><span class="w"> </span><span class="s">"relu"</span><span class="p">})</span><span class="w">
</span><span class="n">fc2</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/fully-connected</span><span class="w"> </span><span class="s">"fc2"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">act1</span><span class="w"> </span><span class="no">:num-hidden</span><span class="w"> </span><span class="mi">64</span><span class="p">})</span><span class="w">
</span><span class="n">act2</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/activation</span><span class="w"> </span><span class="s">"relu2"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">fc2</span><span class="w"> </span><span class="no">:act-type</span><span class="w"> </span><span class="s">"relu"</span><span class="p">})</span><span class="w">
</span><span class="n">fc3</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/fully-connected</span><span class="w"> </span><span class="s">"fc3"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">act2</span><span class="w"> </span><span class="no">:num-hidden</span><span class="w"> </span><span class="mi">10</span><span class="p">})</span><span class="w">
</span><span class="n">out</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/softmax-output</span><span class="w"> </span><span class="s">"softmax"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">fc3</span><span class="p">})]</span><span class="w">
</span><span class="n">out</span><span class="p">)</span><span class="w">
</span><span class="c1">;=&gt;#object[org.apache.mxnet.Symbol 0x1f43a406 "org.apache.mxnet.Symbol@1f43a406"]</span><span class="w">
</span></code></pre></div></div>
<p>You can also write this with the <code class="highlighter-rouge">as-&gt;</code> threading macro.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">out</span><span class="w"> </span><span class="p">(</span><span class="nf">as-&gt;</span><span class="w"> </span><span class="p">(</span><span class="nf">sym/variable</span><span class="w"> </span><span class="s">"data"</span><span class="p">)</span><span class="w"> </span><span class="n">data</span><span class="w">
</span><span class="p">(</span><span class="nf">sym/fully-connected</span><span class="w"> </span><span class="s">"fc1"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="no">:num-hidden</span><span class="w"> </span><span class="mi">128</span><span class="p">})</span><span class="w">
</span><span class="p">(</span><span class="nf">sym/activation</span><span class="w"> </span><span class="s">"relu1"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="no">:act-type</span><span class="w"> </span><span class="s">"relu"</span><span class="p">})</span><span class="w">
</span><span class="p">(</span><span class="nf">sym/fully-connected</span><span class="w"> </span><span class="s">"fc2"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="no">:num-hidden</span><span class="w"> </span><span class="mi">64</span><span class="p">})</span><span class="w">
</span><span class="p">(</span><span class="nf">sym/activation</span><span class="w"> </span><span class="s">"relu2"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="no">:act-type</span><span class="w"> </span><span class="s">"relu"</span><span class="p">})</span><span class="w">
</span><span class="p">(</span><span class="nf">sym/fully-connected</span><span class="w"> </span><span class="s">"fc3"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="w"> </span><span class="no">:num-hidden</span><span class="w"> </span><span class="mi">10</span><span class="p">})</span><span class="w">
</span><span class="p">(</span><span class="nf">sym/softmax-output</span><span class="w"> </span><span class="s">"softmax"</span><span class="w"> </span><span class="p">{</span><span class="no">:data</span><span class="w"> </span><span class="n">data</span><span class="p">})))</span><span class="w">
</span><span class="c1">;=&gt; #'tutorial.module/out</span><span class="w">
</span></code></pre></div></div>
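<p>By default, the <code class="highlighter-rouge">context</code> is the CPU. If you need data parallelization, you can specify a GPU context or a vector of GPU contexts, like this: <code class="highlighter-rouge">(m/module out {:contexts [(context/gpu)]})</code>. This requires the <code class="highlighter-rouge">org.apache.clojure-mxnet.context</code> namespace, aliased as <code class="highlighter-rouge">context</code>, which is not part of the namespace declaration above.</p>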
<p>Before you can compute with a module, you need to call <code class="highlighter-rouge">bind</code> to allocate the device memory and <code class="highlighter-rouge">init-params</code> or <code class="highlighter-rouge">set-params</code> to initialize the parameters. If you simply want to fit a module, you don’t need to call <code class="highlighter-rouge">bind</code> and <code class="highlighter-rouge">init-params</code> explicitly, because the <code class="highlighter-rouge">fit</code> function automatically calls them if they are needed.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">let</span><span class="w"> </span><span class="p">[</span><span class="n">mod</span><span class="w"> </span><span class="p">(</span><span class="nf">m/module</span><span class="w"> </span><span class="n">out</span><span class="p">)]</span><span class="w">
</span><span class="p">(</span><span class="nb">-&gt;</span><span class="w"> </span><span class="n">mod</span><span class="w">
</span><span class="p">(</span><span class="nf">m/bind</span><span class="w"> </span><span class="p">{</span><span class="no">:data-shapes</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/provide-data</span><span class="w"> </span><span class="n">train-data</span><span class="p">)</span><span class="w">
</span><span class="no">:label-shapes</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/provide-label</span><span class="w"> </span><span class="n">train-data</span><span class="p">)})</span><span class="w">
</span><span class="p">(</span><span class="nf">m/init-params</span><span class="p">)))</span><span class="w">
</span></code></pre></div></div>
<p>Now you can compute with the module using functions like <code class="highlighter-rouge">forward</code>, <code class="highlighter-rouge">backward</code>, etc.</p>
<h2 id="training-and-predicting">Training and Predicting</h2>
<p>Modules provide high-level APIs for training, predicting, and evaluating. To fit a module, call the <code class="highlighter-rouge">fit</code> function with some data iterators:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">mod</span><span class="w"> </span><span class="p">(</span><span class="nf">m/fit</span><span class="w"> </span><span class="p">(</span><span class="nf">m/module</span><span class="w"> </span><span class="n">out</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="no">:train-data</span><span class="w"> </span><span class="n">train-data</span><span class="w"> </span><span class="no">:eval-data</span><span class="w"> </span><span class="n">test-data</span><span class="w"> </span><span class="no">:num-epoch</span><span class="w"> </span><span class="mi">1</span><span class="p">}))</span><span class="w">
</span><span class="c1">;; Epoch 0 Train- [accuracy 0.12521666]</span><span class="w">
</span><span class="c1">;; Epoch 0 Time cost- 8392</span><span class="w">
</span><span class="c1">;; Epoch 0 Validation- [accuracy 0.2227]</span><span class="w">
</span></code></pre></div></div>
<p>You can pass batch-end callbacks with <code class="highlighter-rouge">batch-end-callback</code> and epoch-end callbacks with <code class="highlighter-rouge">epoch-end-callback</code> in the <code class="highlighter-rouge">fit-params</code>. You can also set other training parameters, such as <code class="highlighter-rouge">optimizer</code> and <code class="highlighter-rouge">eval-metric</code>, in the <code class="highlighter-rouge">fit-params</code>. To learn more, see the options of the <code class="highlighter-rouge">fit-params</code> function. To predict with a module, call <code class="highlighter-rouge">predict</code> with a DataIter:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">results</span><span class="w"> </span><span class="p">(</span><span class="nf">m/predict</span><span class="w"> </span><span class="n">mod</span><span class="w"> </span><span class="p">{</span><span class="no">:eval-data</span><span class="w"> </span><span class="n">test-data</span><span class="p">}))</span><span class="w">
</span><span class="p">(</span><span class="nb">first</span><span class="w"> </span><span class="n">results</span><span class="p">)</span><span class="w"> </span><span class="c1">;=&gt;#object[org.apache.mxnet.NDArray 0x3540b6d3 "org.apache.mxnet.NDArray@a48686ec"]</span><span class="w">
</span><span class="p">(</span><span class="nb">first</span><span class="w"> </span><span class="p">(</span><span class="nf">ndarray/-&gt;vec</span><span class="w"> </span><span class="p">(</span><span class="nb">first</span><span class="w"> </span><span class="n">results</span><span class="p">)))</span><span class="w"> </span><span class="c1">;=&gt;0.08261358</span><span class="w">
</span></code></pre></div></div>
<p>The module collects and returns all of the prediction results. For more details about the format of the return values, see the documentation for the <a href="/versions/master/api/clojure/docs/api/org.apache.clojure-mxnet.module.html#var-predict"><code class="highlighter-rouge">predict</code></a> function.</p>
<p>When prediction results might be too large to fit in memory, use the <a href="/versions/master/api/clojure/docs/api/org.apache.clojure-mxnet.module.html#var-predict-every-batch"><code class="highlighter-rouge">predict-every-batch</code></a> API.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">let</span><span class="w"> </span><span class="p">[</span><span class="n">preds</span><span class="w"> </span><span class="p">(</span><span class="nf">m/predict-every-batch</span><span class="w"> </span><span class="n">mod</span><span class="w"> </span><span class="p">{</span><span class="no">:eval-data</span><span class="w"> </span><span class="n">test-data</span><span class="p">})]</span><span class="w">
</span><span class="p">(</span><span class="nf">mx-io/reduce-batches</span><span class="w"> </span><span class="n">test-data</span><span class="w">
</span><span class="p">(</span><span class="k">fn</span><span class="w"> </span><span class="p">[</span><span class="n">i</span><span class="w"> </span><span class="n">batch</span><span class="p">]</span><span class="w">
</span><span class="p">(</span><span class="nb">println</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="s">"pred is "</span><span class="w"> </span><span class="p">(</span><span class="nb">first</span><span class="w"> </span><span class="p">(</span><span class="nb">get</span><span class="w"> </span><span class="n">preds</span><span class="w"> </span><span class="n">i</span><span class="p">))))</span><span class="w">
</span><span class="p">(</span><span class="nb">println</span><span class="w"> </span><span class="p">(</span><span class="nb">str</span><span class="w"> </span><span class="s">"label is "</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/batch-label</span><span class="w"> </span><span class="n">batch</span><span class="p">)))</span><span class="w">
</span><span class="c1">;;; do something</span><span class="w">
</span><span class="p">(</span><span class="nb">inc</span><span class="w"> </span><span class="n">i</span><span class="p">))))</span><span class="w">
</span></code></pre></div></div>
<p>If you need to evaluate on a test set and don’t need the prediction output, call the <code class="highlighter-rouge">score</code> function with a data iterator and an eval metric:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="nf">m/score</span><span class="w"> </span><span class="n">mod</span><span class="w"> </span><span class="p">{</span><span class="no">:eval-data</span><span class="w"> </span><span class="n">test-data</span><span class="w"> </span><span class="no">:eval-metric</span><span class="w"> </span><span class="p">(</span><span class="nf">eval-metric/accuracy</span><span class="p">)})</span><span class="w"> </span><span class="c1">;=&gt;["accuracy" 0.2227]</span><span class="w">
</span></code></pre></div></div>
<p>This runs predictions on each batch in the provided data iterator and computes the evaluation score using the provided eval metric. The evaluation results are stored in the <code class="highlighter-rouge">eval-metric</code> object itself so that you can query them later.</p>
<h2 id="saving-and-loading">Saving and Loading</h2>
<p>To save the module parameters in each training epoch, use the <code class="highlighter-rouge">save-checkpoint</code> function:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">let</span><span class="w"> </span><span class="p">[</span><span class="n">save-prefix</span><span class="w"> </span><span class="s">"my-model"</span><span class="p">]</span><span class="w">
</span><span class="p">(</span><span class="nb">doseq</span><span class="w"> </span><span class="p">[</span><span class="n">epoch-num</span><span class="w"> </span><span class="p">(</span><span class="nb">range</span><span class="w"> </span><span class="mi">3</span><span class="p">)]</span><span class="w">
</span><span class="p">(</span><span class="nf">mx-io/do-batches</span><span class="w"> </span><span class="n">train-data</span><span class="w"> </span><span class="p">(</span><span class="k">fn</span><span class="w"> </span><span class="p">[</span><span class="n">batch</span><span class="w">
</span><span class="c1">;; do something</span><span class="w">
</span><span class="p">]))</span><span class="w">
</span><span class="p">(</span><span class="nf">m/save-checkpoint</span><span class="w"> </span><span class="n">mod</span><span class="w"> </span><span class="p">{</span><span class="no">:prefix</span><span class="w"> </span><span class="n">save-prefix</span><span class="w"> </span><span class="no">:epoch</span><span class="w"> </span><span class="n">epoch-num</span><span class="w"> </span><span class="no">:save-opt-states</span><span class="w"> </span><span class="n">true</span><span class="p">})))</span><span class="w">
</span><span class="c1">;; INFO org.apache.mxnet.module.Module: Saved checkpoint to my-model-0000.params</span><span class="w">
</span><span class="c1">;; INFO org.apache.mxnet.module.Module: Saved optimizer state to my-model-0000.states</span><span class="w">
</span><span class="c1">;; INFO org.apache.mxnet.module.Module: Saved checkpoint to my-model-0001.params</span><span class="w">
</span><span class="c1">;; INFO org.apache.mxnet.module.Module: Saved optimizer state to my-model-0001.states</span><span class="w">
</span><span class="c1">;; INFO org.apache.mxnet.module.Module: Saved checkpoint to my-model-0002.params</span><span class="w">
</span><span class="c1">;; INFO org.apache.mxnet.module.Module: Saved optimizer state to my-model-0002.states</span><span class="w">
</span></code></pre></div></div>
<p>To load the saved module parameters, call the <code class="highlighter-rouge">load-checkpoint</code> function:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="k">def</span><span class="w"> </span><span class="n">new-mod</span><span class="w"> </span><span class="p">(</span><span class="nf">m/load-checkpoint</span><span class="w"> </span><span class="p">{</span><span class="no">:prefix</span><span class="w"> </span><span class="s">"my-model"</span><span class="w"> </span><span class="no">:epoch</span><span class="w"> </span><span class="mi">1</span><span class="w"> </span><span class="no">:load-optimizer-states</span><span class="w"> </span><span class="n">true</span><span class="p">}))</span><span class="w">
</span><span class="n">new-mod</span><span class="w"> </span><span class="c1">;=&gt; #object[org.apache.mxnet.module.Module 0x5304d0f4 "org.apache.mxnet.module.Module@5304d0f4"]</span><span class="w">
</span></code></pre></div></div>
<p>To initialize parameters, first bind the symbols to construct executors with the <code class="highlighter-rouge">bind</code> function. Then, initialize the parameters and auxiliary states by calling the <code class="highlighter-rouge">init-params</code> function.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="nb">-&gt;</span><span class="w"> </span><span class="n">new-mod</span><span class="w">
</span><span class="p">(</span><span class="nf">m/bind</span><span class="w"> </span><span class="p">{</span><span class="no">:data-shapes</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/provide-data</span><span class="w"> </span><span class="n">train-data</span><span class="p">)</span><span class="w"> </span><span class="no">:label-shapes</span><span class="w"> </span><span class="p">(</span><span class="nf">mx-io/provide-label</span><span class="w"> </span><span class="n">train-data</span><span class="p">)})</span><span class="w">
</span><span class="p">(</span><span class="nf">m/init-params</span><span class="p">))</span><span class="w">
</span></code></pre></div></div>
<p>To get the current parameters, use the <code class="highlighter-rouge">params</code> function:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="w">
</span><span class="p">(</span><span class="k">let</span><span class="w"> </span><span class="p">[[</span><span class="n">arg-params</span><span class="w"> </span><span class="n">aux-params</span><span class="p">]</span><span class="w"> </span><span class="p">(</span><span class="nf">m/params</span><span class="w"> </span><span class="n">new-mod</span><span class="p">)]</span><span class="w">
</span><span class="p">{</span><span class="no">:arg-params</span><span class="w"> </span><span class="n">arg-params</span><span class="w">
</span><span class="no">:aux-params</span><span class="w"> </span><span class="n">aux-params</span><span class="p">})</span><span class="w">
</span><span class="c1">;; {:arg-params</span><span class="w">
</span><span class="c1">;; {"fc3_bias"</span><span class="w">
</span><span class="c1">;; #object[org.apache.mxnet.NDArray 0x39adc3b0 "org.apache.mxnet.NDArray@49caf426"],</span><span class="w">
</span><span class="c1">;; "fc2_weight"</span><span class="w">
</span><span class="c1">;; #object[org.apache.mxnet.NDArray 0x25baf623 "org.apache.mxnet.NDArray@a6c8f9ac"],</span><span class="w">
</span><span class="c1">;; "fc1_bias"</span><span class="w">
</span><span class="c1">;; #object[org.apache.mxnet.NDArray 0x6e089973 "org.apache.mxnet.NDArray@9f91d6eb"],</span><span class="w">
</span><span class="c1">;; "fc3_weight"</span><span class="w">
</span><span class="c1">;; #object[org.apache.mxnet.NDArray 0x756fd109 "org.apache.mxnet.NDArray@2dd0fe3c"],</span><span class="w">
</span><span class="c1">;; "fc2_bias"</span><span class="w">
</span><span class="c1">;; #object[org.apache.mxnet.NDArray 0x1dc69c8b "org.apache.mxnet.NDArray@d128f73d"],</span><span class="w">
</span><span class="c1">;; "fc1_weight"</span><span class="w">
</span><span class="c1">;; #object[org.apache.mxnet.NDArray 0x20abc769 "org.apache.mxnet.NDArray@b8e1c5e8"]},</span><span class="w">
</span><span class="c1">;; :aux-params {}}</span><span class="w">
</span></code></pre></div></div>
<p>To assign parameter and auxiliary state values, use the <code class="highlighter-rouge">set-params</code> function:</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="nf">m/set-params</span><span class="w"> </span><span class="n">new-mod</span><span class="w"> </span><span class="p">{</span><span class="no">:arg-params</span><span class="w"> </span><span class="p">(</span><span class="nf">m/arg-params</span><span class="w"> </span><span class="n">new-mod</span><span class="p">)</span><span class="w"> </span><span class="no">:aux-params</span><span class="w"> </span><span class="p">(</span><span class="nf">m/aux-params</span><span class="w"> </span><span class="n">new-mod</span><span class="p">)})</span><span class="w">
</span><span class="c1">;=&gt; #object[org.apache.mxnet.module.Module 0x5304d0f4 "org.apache.mxnet.module.Module@5304d0f4"]</span><span class="w">
</span></code></pre></div></div>
<p>To resume training from a saved checkpoint, pass the loaded parameters to the <code class="highlighter-rouge">fit</code> function. This prevents <code class="highlighter-rouge">fit</code> from initializing the parameters randomly.</p>
<p>Create <code class="highlighter-rouge">fit-params</code> and use it to set <code class="highlighter-rouge">begin-epoch</code> so that <code class="highlighter-rouge">fit</code> knows to resume from a saved epoch.</p>
<div class="language-clojure highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">;; reset the training data before calling fit or you will get an error</span><span class="w">
</span><span class="p">(</span><span class="nf">mx-io/reset</span><span class="w"> </span><span class="n">train-data</span><span class="p">)</span><span class="w">
</span><span class="p">(</span><span class="nf">mx-io/reset</span><span class="w"> </span><span class="n">test-data</span><span class="p">)</span><span class="w">
</span><span class="p">(</span><span class="nf">m/fit</span><span class="w"> </span><span class="n">new-mod</span><span class="w"> </span><span class="p">{</span><span class="no">:train-data</span><span class="w"> </span><span class="n">train-data</span><span class="w"> </span><span class="no">:eval-data</span><span class="w"> </span><span class="n">test-data</span><span class="w"> </span><span class="no">:num-epoch</span><span class="w"> </span><span class="mi">2</span><span class="w">
</span><span class="no">:fit-params</span><span class="w"> </span><span class="p">(</span><span class="nb">-&gt;</span><span class="w"> </span><span class="p">(</span><span class="nf">m/fit-params</span><span class="w"> </span><span class="p">{</span><span class="no">:begin-epoch</span><span class="w"> </span><span class="mi">1</span><span class="p">}))})</span><span class="w">
</span></code></pre></div></div>
<h2 id="next-steps">Next Steps</h2>
<ul>
<li>See <a href="symbol">Symbolic API</a> for operations on symbols that assemble neural networks from layers.</li>
<li>See <a href="ndarray">NDArray API</a> for vector/matrix/tensor operations.</li>
<li>See <a href="kvstore">KVStore API</a> for multi-GPU and multi-host distributed training.</li>
</ul>
</div>
</div>
</div>
</div>
</article>
</main><footer class="site-footer h-card">
<div class="wrapper">
<div class="row">
<div class="col-4">
<h4 class="footer-category-title">Resources</h4>
<ul class="contact-list">
<li><a href="/versions/master/community#stay-connected">Mailing lists</a></li>
<li><a href="https://discuss.mxnet.io">MXNet Discuss forum</a></li>
<li><a href="/versions/master/community#github-issues">Github Issues</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/projects">Projects</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a href="/versions/master/community">Contribute To MXNet</a></li>
</ul>
</div>
<div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/incubator-mxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/incubator-mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/master/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul>
</div>
<div class="col-4 footer-text">
<p>A flexible and efficient library for deep learning.</p>
</div>
</div>
</div>
</footer>
<footer class="site-footer2">
<div class="wrapper">
<div class="row">
<div class="col-3">
<img src="/versions/master/assets/img/apache_incubator_logo.png" class="footer-logo col-2" alt="Apache Incubator logo">
</div>
<div class="footer-bottom-warning col-9">
<p>Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <span
style="font-weight:bold">sponsored by the <i>Apache Incubator</i></span>. Incubation is required
of all newly accepted projects until a further review indicates that the infrastructure,
communications, and decision making process have stabilized in a manner consistent with other
successful ASF projects. While incubation status is not necessarily a reflection of the completeness
or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p><p>Copyright © 2017-2018, The Apache Software Foundation. Apache MXNet, MXNet, Apache, the Apache
feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the
Apache Software Foundation.</p>
</div>
</div>
</div>
</footer>
</body>
</html>