blob: e233897427b1db33f9517a697b39c51e6daaec76 [file] [log] [blame]
<!DOCTYPE html>
<!---
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
<html lang="en"><head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="/versions/1.9.1/assets/img/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 -->
<title>Module API | Apache MXNet</title>
<meta name="generator" content="Jekyll v3.8.6" />
<meta property="og:title" content="Module API" />
<meta property="og:locale" content="en_US" />
<meta name="description" content="A flexible and efficient library for deep learning." />
<meta property="og:description" content="A flexible and efficient library for deep learning." />
<link rel="canonical" href="https://mxnet.apache.org/versions/1.9.1/api/scala/docs/tutorials/module" />
<meta property="og:url" content="https://mxnet.apache.org/versions/1.9.1/api/scala/docs/tutorials/module" />
<meta property="og:site_name" content="Apache MXNet" />
<script type="application/ld+json">
{"description":"A flexible and efficient library for deep learning.","headline":"Module API","@type":"WebPage","url":"https://mxnet.apache.org/versions/1.9.1/api/scala/docs/tutorials/module","@context":"https://schema.org"}</script>
<!-- End Jekyll SEO tag -->
<link rel="stylesheet" href="/versions/1.9.1/assets/docsearch.min.css" /><link rel="stylesheet" href="/versions/1.9.1/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/1.9.1/feed.xml" title="Apache MXNet" /><!-- Matomo -->
<script>
// Matomo analytics bootstrap. Tracker commands are queued on the global
// _paq array (an existing array is reused if one was already created);
// presumably the asynchronously loaded matomo.js drains this queue - standard
// Matomo async pattern, confirm against Matomo docs if changing order.
var _paq = window._paq = window._paq || [];
/* tracker methods like "setCustomDimension" should be called before "trackPageView" */
/* We explicitly disable cookie tracking to avoid privacy issues */
_paq.push(['disableCookies']);
_paq.push(['trackPageView']);
_paq.push(['enableLinkTracking']);
// IIFE keeps the helper variables (u, d, g, s) out of the global scope.
(function() {
// Base URL of the Apache-hosted analytics instance; both the tracker
// endpoint and the tracker script are resolved relative to it.
var u="https://analytics.apache.org/";
_paq.push(['setTrackerUrl', u+'matomo.php']);
_paq.push(['setSiteId', '23']);
// Create an async script element for matomo.js and insert it before the
// first script element in the document, so loading it does not block parsing.
var d=document, g=d.createElement('script'), s=d.getElementsByTagName('script')[0];
g.async=true; g.src=u+'matomo.js'; s.parentNode.insertBefore(g,s);
})();
</script>
<!-- End Matomo Code -->
<script src="/versions/1.9.1/assets/js/jquery-3.3.1.min.js"></script>
<script src="/versions/1.9.1/assets/js/docsearch.min.js"></script><script src="/versions/1.9.1/assets/js/globalSearch.js" defer></script>
<script src="/versions/1.9.1/assets/js/clipboard.js" defer></script>
<script src="/versions/1.9.1/assets/js/copycode.js" defer></script></head>
<body><header class="site-header" role="banner">
<script>
$(function () {
  // Header opacity: the banner background starts at alpha 0.4 and gains
  // 1.0 of alpha for every 300px scrolled (CSS clamps values above 1).
  var updateHeaderOpacity = function () {
    var alpha = $(window).scrollTop() / 300 + 0.4;
    $('.site-header').css("background-color", "rgba(4,140,204," + alpha + ")");
  };
  $(window).on("scroll", updateHeaderOpacity);
  updateHeaderOpacity();

  // Menu selector: add the "page-current" class to every nav link whose
  // URL occurs within the current page URL.
  var currentUrl = window.location.href;
  $('.page-link').each(function (_, link) {
    if (currentUrl.includes(link.href)) {
      $(link).addClass("page-current");
    }
  });
});
</script>
<div class="wrapper">
<a class="site-title" rel="author" href="/versions/1.9.1/"><img
src="/versions/1.9.1/assets/img/mxnet_logo.png" class="site-header-logo" alt="Apache MXNet"></a>
<nav class="site-nav">
<input type="checkbox" id="nav-trigger" class="nav-trigger"/>
<label for="nav-trigger">
<span class="menu-icon">
<svg viewBox="0 0 18 15" width="18px" height="15px">
<path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
</svg>
</span>
</label>
<div class="gs-search-border">
<div id="gs-search-icon"></div>
<form id="global-search-form">
<input id="global-search" type="text" title="Search" placeholder="Search" />
<div id="global-search-dropdown-container">
<button class="gs-current-version btn" type="button" data-toggle="dropdown">
<span id="gs-current-version-label">1.9.1</span>
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown">
<li class="gs-opt gs-versions">master</li>
<li class="gs-opt gs-versions active">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
<span id="global-search-close">x</span>
</form>
</div>
<div class="trigger">
<div id="global-search-mobile-border">
<div id="gs-search-icon-mobile"></div>
<input id="global-search-mobile" placeholder="Search..." type="text"/>
<div id="global-search-dropdown-container-mobile">
<button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown">
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown-mobile">
<li class="gs-opt gs-versions">master</li>
<li class="gs-opt gs-versions active">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
</div>
<a class="page-link" href="/versions/1.9.1/get_started">Get Started</a>
<a class="page-link" href="/versions/1.9.1/features">Features</a>
<a class="page-link" href="/versions/1.9.1/ecosystem">Ecosystem</a>
<a class="page-link" href="/versions/1.9.1/api">Docs & Tutorials</a>
<a class="page-link" href="/versions/1.9.1/trusted_by">Trusted By</a>
<a class="page-link" href="https://github.com/apache/mxnet">GitHub</a>
<div class="dropdown" style="min-width:100px">
<span class="dropdown-header">Apache
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content" style="min-width:250px">
<a href="https://www.apache.org/foundation/">Apache Software Foundation</a>
<a href="https://www.apache.org/licenses/">License</a>
<a href="/versions/1.9.1/api/faq/security.html">Security</a>
<a href="https://privacy.apache.org/policies/privacy-policy-public.html">Privacy</a>
<a href="https://www.apache.org/events/current-event">Events</a>
<a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a>
<a href="https://www.apache.org/foundation/thanks.html">Thanks</a>
</div>
</div>
<div class="dropdown">
<span class="dropdown-header">1.9.1
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content">
<a href="/">master</a>
<a class="dropdown-option-active" href="/versions/1.9.1/">1.9.1</a>
<a href="/versions/1.8.0/">1.8.0</a>
<a href="/versions/1.7.0/">1.7.0</a>
<a href="/versions/1.6.0/">1.6.0</a>
<a href="/versions/1.5.0/">1.5.0</a>
<a href="/versions/1.4.1/">1.4.1</a>
<a href="/versions/1.3.1/">1.3.1</a>
<a href="/versions/1.2.1/">1.2.1</a>
<a href="/versions/1.1.0/">1.1.0</a>
<a href="/versions/1.0.0/">1.0.0</a>
<a href="/versions/0.12.1/">0.12.1</a>
<a href="/versions/0.11.0/">0.11.0</a>
</div>
</div>
</div>
</nav>
</div>
</header>
<main class="page-content" aria-label="Content">
<script>
</script>
<article class="post">
<header class="post-header wrapper">
<h1 class="post-title">Module API</h1>
<h3></h3></header>
<div class="post-content">
<div class="wrapper">
<div class="row">
<div class="col-3 docs-side-bar">
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<div class="docs-card docs-side">
<ul>
<div class="docs-action-btn">
<a href="/versions/1.9.1/api/scala.html"> <img src="/versions/1.9.1/assets/img/compass.svg"
class="docs-logo-docs">Scala Guide <span
class="span-accented"></span></a>
</div>
<div class="docs-action-btn">
<a href="/versions/1.9.1/api/scala/docs/tutorials"> <img
src="/versions/1.9.1/assets/img/video-tutorial.svg" class="docs-logo-docs">Scala
Tutorials <span class="span-accented"></span></a>
</div>
<div class="docs-action-btn">
<a href="/versions/1.9.1/api/scala/docs/api"> <img src="/versions/1.9.1/assets/img/api.svg"
class="docs-logo-docs">Scala API Reference
<span class="span-accented"></span></a>
</div>
<!-- Let's show the list of tutorials -->
<br>
<h3>Tutorials</h3>
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/char_lstm">Char-LSTM</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/infer">Infer API</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/io">Data Loading API</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/kvstore">KVStore API</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/mnist">MNIST Example</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/model">Model API *Deprecated*</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/module">Module API</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/mxnet_scala_on_intellij">Scala on IntelliJ</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/ndarray">NDArray</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/symbol">Symbol API</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/symbol_in_pictures">Symbol in Pictures</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page -->
</ul>
</div>
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page -->
</div>
<div class="col-9">
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
<h1 id="module-api">Module API</h1>
<p>The module API provides an intermediate- and high-level interface for performing computation with neural networks in MXNet. A <em>module</em> is an instance of a subclass of <code>BaseModule</code>. The most widely used module class is <code>Module</code>, which wraps a <code>Symbol</code> and one or more <code>Executor</code>s. For a full list of functions, see <code>BaseModule</code>.
Subclasses of <code>BaseModule</code> might provide additional interface functions. This topic provides examples of common use cases. All of the module APIs are in the <code>org.apache.mxnet.module</code> package.</p>
<h2 id="preparing-a-module-for-computation">Preparing a Module for Computation</h2>
<p>To construct a module, refer to the constructors for the module class. For example, the <code>Module</code> class accepts a <code>Symbol</code> as input:</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="k">import</span> <span class="nn">org.apache.mxnet._</span>
<span class="k">import</span> <span class="nn">org.apache.mxnet.module.</span><span class="o">{</span><span class="nc">FitParams</span><span class="o">,</span> <span class="nc">Module</span><span class="o">}</span>
<span class="c1">// construct a simple MLP
</span> <span class="k">val</span> <span class="nv">data</span> <span class="k">=</span> <span class="nv">Symbol</span><span class="o">.</span><span class="py">Variable</span><span class="o">(</span><span class="s">"data"</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">fc1</span> <span class="k">=</span> <span class="nv">Symbol</span><span class="o">.</span><span class="py">api</span><span class="o">.</span><span class="py">FullyConnected</span><span class="o">(</span><span class="nc">Some</span><span class="o">(</span><span class="n">data</span><span class="o">),</span> <span class="n">num_hidden</span> <span class="k">=</span> <span class="mi">128</span><span class="o">,</span> <span class="n">name</span> <span class="k">=</span> <span class="s">"fc1"</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">act1</span> <span class="k">=</span> <span class="nv">Symbol</span><span class="o">.</span><span class="py">api</span><span class="o">.</span><span class="py">Activation</span><span class="o">(</span><span class="nc">Some</span><span class="o">(</span><span class="n">fc1</span><span class="o">),</span> <span class="s">"relu"</span><span class="o">,</span> <span class="s">"relu1"</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">fc2</span> <span class="k">=</span> <span class="nv">Symbol</span><span class="o">.</span><span class="py">api</span><span class="o">.</span><span class="py">FullyConnected</span><span class="o">(</span><span class="nc">Some</span><span class="o">(</span><span class="n">act1</span><span class="o">),</span> <span class="n">num_hidden</span> <span class="k">=</span> <span class="mi">64</span><span class="o">,</span> <span class="n">name</span> <span class="k">=</span> <span class="s">"fc2"</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">act2</span> <span class="k">=</span> <span class="nv">Symbol</span><span class="o">.</span><span class="py">api</span><span class="o">.</span><span class="py">Activation</span><span class="o">(</span><span class="nc">Some</span><span class="o">(</span><span class="n">fc2</span><span class="o">),</span> <span class="s">"relu"</span><span class="o">,</span> <span class="s">"relu2"</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">fc3</span> <span class="k">=</span> <span class="nv">Symbol</span><span class="o">.</span><span class="py">api</span><span class="o">.</span><span class="py">FullyConnected</span><span class="o">(</span><span class="nc">Some</span><span class="o">(</span><span class="n">act2</span><span class="o">),</span> <span class="n">num_hidden</span> <span class="k">=</span> <span class="mi">10</span><span class="o">,</span> <span class="n">name</span> <span class="k">=</span> <span class="s">"fc3"</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">out</span> <span class="k">=</span> <span class="nv">Symbol</span><span class="o">.</span><span class="py">api</span><span class="o">.</span><span class="py">SoftmaxOutput</span><span class="o">(</span><span class="nc">Some</span><span class="o">(</span><span class="n">fc3</span><span class="o">),</span> <span class="n">name</span> <span class="k">=</span> <span class="s">"softmax"</span><span class="o">)</span>
<span class="c1">// construct the module
</span> <span class="k">val</span> <span class="nv">mod</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Module</span><span class="o">(</span><span class="n">out</span><span class="o">)</span>
</code></pre></div>
<p>By default, the module runs on the CPU context. If you need data parallelism, you can specify a GPU context or an array of GPU contexts when constructing the module.</p>
<p>Before you can compute with a module, you need to call <code>bind()</code> to allocate the device memory and <code>initParams()</code> or <code>setParams()</code> to initialize the parameters.
If you simply want to fit a module, you don&#39;t need to call <code>bind()</code> and <code>initParams()</code> explicitly, because the <code>fit()</code> function automatically calls them if they are needed.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="nv">mod</span><span class="o">.</span><span class="py">bind</span><span class="o">(</span><span class="n">dataShapes</span> <span class="k">=</span> <span class="nv">train_dataiter</span><span class="o">.</span><span class="py">provideData</span><span class="o">,</span> <span class="n">labelShapes</span> <span class="k">=</span> <span class="nc">Some</span><span class="o">(</span><span class="nv">train_dataiter</span><span class="o">.</span><span class="py">provideLabel</span><span class="o">))</span>
<span class="nv">mod</span><span class="o">.</span><span class="py">initParams</span><span class="o">()</span>
</code></pre></div>
<p>Now you can compute with the module using functions like <code>forward()</code>, <code>backward()</code>, etc.</p>
<h2 id="training-predicting-and-evaluating">Training, Predicting, and Evaluating</h2>
<p>Modules provide high-level APIs for training, predicting, and evaluating. To fit a module, call the <code>fit()</code> function with some <code>DataIter</code>s:</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="k">import</span> <span class="nn">org.apache.mxnet.optimizer.SGD</span>
<span class="k">val</span> <span class="nv">mod</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Module</span><span class="o">(</span><span class="n">out</span><span class="o">)</span>
<span class="nv">mod</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">train_dataiter</span><span class="o">,</span> <span class="n">evalData</span> <span class="k">=</span> <span class="nv">scala</span><span class="o">.</span><span class="py">Option</span><span class="o">(</span><span class="n">eval_dataiter</span><span class="o">),</span>
<span class="n">numEpoch</span> <span class="k">=</span> <span class="n">n_epoch</span><span class="o">,</span> <span class="n">fitParams</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">FitParams</span><span class="o">()</span>
<span class="o">.</span><span class="py">setOptimizer</span><span class="o">(</span><span class="k">new</span> <span class="nc">SGD</span><span class="o">(</span><span class="n">learningRate</span> <span class="k">=</span> <span class="mf">0.1f</span><span class="o">,</span> <span class="n">momentum</span> <span class="k">=</span> <span class="mf">0.9f</span><span class="o">,</span> <span class="n">wd</span> <span class="k">=</span> <span class="mf">0.0001f</span><span class="o">)))</span>
</code></pre></div>
<p>The interface is very similar to the old <code>FeedForward</code> class. You can pass in batch-end callbacks using <code>setBatchEndCallback</code> and epoch-end callbacks using <code>setEpochEndCallback</code>. You can also set parameters using methods like <code>setOptimizer</code> and <code>setEvalMetric</code>. To learn more about the <code>FitParams</code> class, see the <a href="/versions/1.9.1/api/scala/docs/api/#org.apache.mxnet.module.FitParams">API page</a>. To predict with a module, call <code>predict()</code> with a <code>DataIter</code>:</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="nv">mod</span><span class="o">.</span><span class="py">predict</span><span class="o">(</span><span class="n">val_dataiter</span><span class="o">)</span>
</code></pre></div>
<p>The module collects and returns all of the prediction results. For more details about the format of the return values, see the documentation for the <a href="/versions/1.9.1/api/scala/docs/api/#org.apache.mxnet.module.BaseModule"><code>predict()</code> function</a>.</p>
<p>When prediction results might be too large to fit in memory, use the <code>predictEveryBatch</code> API:</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="k">val</span> <span class="nv">preds</span> <span class="k">=</span> <span class="nv">mod</span><span class="o">.</span><span class="py">predictEveryBatch</span><span class="o">(</span><span class="n">val_dataiter</span><span class="o">)</span>
<span class="nv">val_dataiter</span><span class="o">.</span><span class="py">reset</span><span class="o">()</span>
<span class="k">var</span> <span class="n">i</span> <span class="k">=</span> <span class="mi">0</span>
<span class="nf">while</span> <span class="o">(</span><span class="nv">val_dataiter</span><span class="o">.</span><span class="py">hasNext</span><span class="o">)</span> <span class="o">{</span>
<span class="k">val</span> <span class="nv">batch</span> <span class="k">=</span> <span class="nv">val_dataiter</span><span class="o">.</span><span class="py">next</span><span class="o">()</span>
<span class="k">val</span> <span class="nv">predLabel</span><span class="k">:</span> <span class="kt">Array</span><span class="o">[</span><span class="kt">Int</span><span class="o">]</span> <span class="k">=</span> <span class="nv">NDArray</span><span class="o">.</span><span class="py">argmax_channel</span><span class="o">(</span><span class="nf">preds</span><span class="o">(</span><span class="n">i</span><span class="o">)(</span><span class="mi">0</span><span class="o">)).</span><span class="py">toArray</span><span class="o">.</span><span class="py">map</span><span class="o">(</span><span class="nv">_</span><span class="o">.</span><span class="py">toInt</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">label</span> <span class="k">=</span> <span class="nv">batch</span><span class="o">.</span><span class="py">label</span><span class="o">(</span><span class="mi">0</span><span class="o">).</span><span class="py">toArray</span><span class="o">.</span><span class="py">map</span><span class="o">(</span><span class="nv">_</span><span class="o">.</span><span class="py">toInt</span><span class="o">)</span>
<span class="c1">//do something...
</span> <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="o">}</span>
</code></pre></div>
<p>If you need to evaluate on a test set and don&#39;t need the prediction output, call the <code>score()</code> function with a <code>DataIter</code> and an <code>EvalMetric</code>:</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="nv">mod</span><span class="o">.</span><span class="py">score</span><span class="o">(</span><span class="n">val_dataiter</span><span class="o">,</span> <span class="n">metric</span><span class="o">)</span>
</code></pre></div>
<p>This runs predictions on each batch in the provided <code>DataIter</code> and computes the evaluation score using the provided <code>EvalMetric</code>. The evaluation results are stored in <code>metric</code> so that you can query them later.</p>
<h2 id="saving-and-loading-module-parameters">Saving and Loading Module Parameters</h2>
<p>To save the module parameters after each training epoch, call <code>saveCheckpoint</code> at the end of the epoch:</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="k">val</span> <span class="nv">modelPrefix</span><span class="k">:</span> <span class="kt">String</span> <span class="o">=</span> <span class="s">"mymodel"</span>
<span class="nf">for</span> <span class="o">(</span><span class="n">epoch</span> <span class="k">&lt;-</span> <span class="mi">0</span> <span class="n">until</span> <span class="mi">5</span><span class="o">)</span> <span class="o">{</span>
<span class="nv">train_dataiter</span><span class="o">.</span><span class="py">reset</span><span class="o">()</span>
<span class="nf">while</span><span class="o">(</span><span class="nv">train_dataiter</span><span class="o">.</span><span class="py">hasNext</span><span class="o">){</span>
<span class="k">val</span> <span class="nv">batch</span> <span class="k">=</span> <span class="nv">train_dataiter</span><span class="o">.</span><span class="py">next</span><span class="o">()</span>
<span class="c1">// forward backward pass
</span> <span class="c1">//do something...
</span> <span class="o">}</span>
<span class="nv">mod</span><span class="o">.</span><span class="py">saveCheckpoint</span><span class="o">(</span><span class="n">modelPrefix</span><span class="o">,</span> <span class="n">epoch</span><span class="o">,</span> <span class="n">saveOptStates</span> <span class="k">=</span> <span class="kc">true</span><span class="o">)</span>
<span class="o">}</span>
</code></pre></div>
<p>To load the saved module parameters, call the <code>loadCheckpoint</code> function:</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="k">val</span> <span class="nv">mod</span> <span class="k">=</span> <span class="nv">Module</span><span class="o">.</span><span class="py">loadCheckpoint</span><span class="o">(</span><span class="n">modelPrefix</span><span class="o">,</span> <span class="n">loadModelEpoch</span><span class="o">,</span> <span class="n">loadOptimizerStates</span> <span class="k">=</span> <span class="kc">true</span><span class="o">)</span>
</code></pre></div>
<p>To initialize parameters, first bind the symbols to construct executors with the <code>bind</code> method. Then initialize the parameters and auxiliary states by calling the <code>initParams()</code> method.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="nv">mod</span><span class="o">.</span><span class="py">bind</span><span class="o">(</span><span class="n">dataShapes</span> <span class="k">=</span> <span class="nv">train_dataiter</span><span class="o">.</span><span class="py">provideData</span><span class="o">,</span> <span class="n">labelShapes</span> <span class="k">=</span> <span class="nc">Some</span><span class="o">(</span><span class="nv">train_dataiter</span><span class="o">.</span><span class="py">provideLabel</span><span class="o">))</span>
<span class="nv">mod</span><span class="o">.</span><span class="py">initParams</span><span class="o">()</span>
</code></pre></div>
<p>To get the current parameters, use the <code>getParams</code> method.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="k">val</span> <span class="o">(</span><span class="n">argParams</span><span class="o">,</span> <span class="n">auxParams</span><span class="o">)</span> <span class="k">=</span> <span class="nv">mod</span><span class="o">.</span><span class="py">getParams</span>
</code></pre></div>
<p>To assign parameter and aux state values, use the <code>setParams</code> method.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="nv">mod</span><span class="o">.</span><span class="py">setParams</span><span class="o">(</span><span class="n">argParams</span><span class="o">,</span> <span class="n">auxParams</span><span class="o">)</span>
</code></pre></div>
<p>To resume training from a saved checkpoint, instead of calling <code>setParams()</code>, directly call <code>fit()</code>, passing the loaded parameters, so that <code>fit()</code> knows to start from those parameters instead of initializing randomly:</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="nv">mod</span><span class="o">.</span><span class="py">fit</span><span class="o">(...,</span> <span class="n">fitParams</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">FitParams</span><span class="o">().</span><span class="py">setArgParams</span><span class="o">(</span><span class="n">argParams</span><span class="o">).</span>
<span class="py">setAuxParams</span><span class="o">(</span><span class="n">auxParams</span><span class="o">).</span><span class="py">setBeginEpoch</span><span class="o">(</span><span class="n">beginEpoch</span><span class="o">))</span>
</code></pre></div>
<p>Create an instance of the <code>FitParams</code> class and call its <code>setBeginEpoch()</code> method with <code>beginEpoch</code> so that <code>fit()</code> knows to resume from the saved epoch.</p>
<h2 id="next-steps">Next Steps</h2>
<ul>
<li>See <a href="model">Model API</a> for an alternative simple high-level interface for training neural networks.</li>
<li>See <a href="symbol">Symbolic API</a> for operations on symbols that assemble neural networks from layers.</li>
<li>See <a href="io">IO Data Loading API</a> for parsing and loading data.</li>
<li>See <a href="ndarray">NDArray API</a> for vector/matrix/tensor operations.</li>
<li>See <a href="kvstore">KVStore API</a> for multi-GPU and multi-host distributed training.</li>
</ul>
</div>
</div>
</div>
</div>
</article>
</main><footer class="site-footer h-card">
<div class="wrapper">
<div class="row">
<div class="col-4">
<h4 class="footer-category-title">Resources</h4>
<ul class="contact-list">
<li><a href="/versions/1.9.1/community/contribute#mxnet-dev-communications">Mailing lists</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a href="https://issues.apache.org/jira/projects/MXNET/issues">Jira Tracker</a></li>
<li><a href="https://github.com/apache/mxnet/labels/Roadmap">GitHub Roadmap</a></li>
<li><a href="https://medium.com/apache-mxnet">Blog</a></li>
<li><a href="https://discuss.mxnet.io">Forum</a></li>
<li><a href="/versions/1.9.1/community/contribute">Contribute</a></li>
</ul>
</div>
<div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/mxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul>
</div>
<div class="col-4 footer-text">
<p>A flexible and efficient library for deep learning.</p>
</div>
</div>
</div>
</footer>
<footer class="site-footer2">
<div class="wrapper">
<div class="row">
<div class="col-3">
<img src="/versions/1.9.1/assets/img/asf_logo.svg" class="footer-logo col-2" alt="The Apache Software Foundation logo">
</div>
<div class="footer-bottom-warning col-9">
<p>"Copyright © 2017-2022, The Apache Software Foundation. Licensed under the Apache License, Version 2.0. Apache MXNet, MXNet, Apache, the Apache
feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the
Apache Software Foundation."</p>
</div>
</div>
</div>
</footer>
</body>
</html>