blob: a8f1ba0e95af1d732f18657dadc9d0af7f2ba288 [file] [log] [blame]
<!DOCTYPE html>
<!---
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
<html lang="en"><head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="/versions/1.9.1/assets/img/mxnet-icon.png" rel="icon" type="image/png"><!-- Begin Jekyll SEO tag v2.6.1 -->
<title>Model API Deprecated | Apache MXNet</title>
<meta name="generator" content="Jekyll v3.8.6" />
<meta property="og:title" content="Model API Deprecated" />
<meta property="og:locale" content="en_US" />
<meta name="description" content="A flexible and efficient library for deep learning." />
<meta property="og:description" content="A flexible and efficient library for deep learning." />
<link rel="canonical" href="https://mxnet.apache.org/versions/1.9.1/api/scala/docs/tutorials/model" />
<meta property="og:url" content="https://mxnet.apache.org/versions/1.9.1/api/scala/docs/tutorials/model" />
<meta property="og:site_name" content="Apache MXNet" />
<script type="application/ld+json">
{"description":"A flexible and efficient library for deep learning.","headline":"Model API Deprecated","@type":"WebPage","url":"https://mxnet.apache.org/versions/1.9.1/api/scala/docs/tutorials/model","@context":"https://schema.org"}</script>
<!-- End Jekyll SEO tag -->
<link rel="stylesheet" href="/versions/1.9.1/assets/docsearch.min.css" /><link rel="stylesheet" href="/versions/1.9.1/assets/main.css"><link type="application/atom+xml" rel="alternate" href="https://mxnet.apache.org/versions/1.9.1/feed.xml" title="Apache MXNet" /><!-- Matomo -->
<script>
/* Matomo page-view tracking against analytics.apache.org (site id 23). */
/* Tracker commands are queued on the global _paq array, reusing one that already exists.
   NOTE(review): presumably matomo.js, loaded asynchronously below, drains this queue - confirm against Matomo docs. */
var _paq = window._paq = window._paq || [];
/* tracker methods like "setCustomDimension" should be called before "trackPageView" */
/* We explicitly disable cookie tracking to avoid privacy issues */
_paq.push(['disableCookies']);
_paq.push(['trackPageView']);
_paq.push(['enableLinkTracking']);
(function() {
/* Base URL of the Apache Matomo instance; both the tracker endpoint and the tracker script live under it. */
var u="https://analytics.apache.org/";
_paq.push(['setTrackerUrl', u+'matomo.php']);
_paq.push(['setSiteId', '23']);
/* Load matomo.js without blocking page parsing: an async script element is created
   and inserted before the first script element in the document. */
var d=document, g=d.createElement('script'), s=d.getElementsByTagName('script')[0];
g.async=true; g.src=u+'matomo.js'; s.parentNode.insertBefore(g,s);
})();
</script>
<!-- End Matomo Code -->
<script src="/versions/1.9.1/assets/js/jquery-3.3.1.min.js"></script>
<script src="/versions/1.9.1/assets/js/docsearch.min.js"></script><script src="/versions/1.9.1/assets/js/globalSearch.js" defer></script>
<script src="/versions/1.9.1/assets/js/clipboard.js" defer></script>
<script src="/versions/1.9.1/assets/js/copycode.js" defer></script></head>
<body><header class="site-header" role="banner">
<script>
$(function () {
    // HEADER OPACITY LOGIC
    // The header's blue background starts at alpha 0.4 at the top of the
    // page and gains 1 unit of alpha for every 300px scrolled.
    function updateHeaderOpacity() {
        var alpha = $(window).scrollTop() / 300 + 0.4;
        $('.site-header').css("background-color", "rgba(4,140,204," + alpha + ")");
    }
    $(window).on('scroll', updateHeaderOpacity);
    updateHeaderOpacity();

    // MENU SELECTOR LOGIC
    // Mark every nav link whose href occurs somewhere in the current URL.
    var currentUrl = window.location.href;
    $('.page-link').each(function () {
        if (currentUrl.includes(this.href)) {
            $(this).addClass("page-current");
        }
    });
});
</script>
<div class="wrapper">
<a class="site-title" rel="author" href="/versions/1.9.1/"><img
src="/versions/1.9.1/assets/img/mxnet_logo.png" class="site-header-logo" alt="Apache MXNet home"></a>
<nav class="site-nav">
<input type="checkbox" id="nav-trigger" class="nav-trigger"/>
<label for="nav-trigger">
<span class="menu-icon">
<svg viewBox="0 0 18 15" width="18px" height="15px">
<path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
</svg>
</span>
</label>
<div class="gs-search-border">
<div id="gs-search-icon"></div>
<form id="global-search-form">
<input id="global-search" type="text" title="Search" placeholder="Search" />
<div id="global-search-dropdown-container">
<button class="gs-current-version btn" type="button" data-toggle="dropdown">
<span id="gs-current-version-label">1.9.1</span>
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown">
<li class="gs-opt gs-versions">master</li>
<li class="gs-opt gs-versions active">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
<span id="global-search-close">x</span>
</form>
</div>
<div class="trigger">
<div id="global-search-mobile-border">
<div id="gs-search-icon-mobile"></div>
<input id="global-search-mobile" placeholder="Search..." type="text" aria-label="Search"/>
<div id="global-search-dropdown-container-mobile">
<button class="gs-current-version-mobile btn" type="button" data-toggle="dropdown">
<svg class="gs-dropdown-caret" viewBox="0 0 32 32" aria-hidden="true">
<path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path>
</svg>
</button>
<ul class="gs-opt-group gs-version-dropdown-mobile">
<li class="gs-opt gs-versions">master</li>
<li class="gs-opt gs-versions active">1.9.1</li>
<li class="gs-opt gs-versions">1.8.0</li>
<li class="gs-opt gs-versions">1.7.0</li>
<li class="gs-opt gs-versions">1.6.0</li>
<li class="gs-opt gs-versions">1.5.0</li>
<li class="gs-opt gs-versions">1.4.1</li>
<li class="gs-opt gs-versions">1.3.1</li>
<li class="gs-opt gs-versions">1.2.1</li>
<li class="gs-opt gs-versions">1.1.0</li>
<li class="gs-opt gs-versions">1.0.0</li>
<li class="gs-opt gs-versions">0.12.1</li>
<li class="gs-opt gs-versions">0.11.0</li>
</ul>
</div>
</div>
<a class="page-link" href="/versions/1.9.1/get_started">Get Started</a>
<a class="page-link" href="/versions/1.9.1/features">Features</a>
<a class="page-link" href="/versions/1.9.1/ecosystem">Ecosystem</a>
<a class="page-link" href="/versions/1.9.1/api">Docs & Tutorials</a>
<a class="page-link" href="/versions/1.9.1/trusted_by">Trusted By</a>
<a class="page-link" href="https://github.com/apache/mxnet">GitHub</a>
<div class="dropdown" style="min-width:100px">
<span class="dropdown-header">Apache
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content" style="min-width:250px">
<a href="https://www.apache.org/foundation/">Apache Software Foundation</a>
<a href="https://www.apache.org/licenses/">License</a>
<a href="/versions/1.9.1/api/faq/security.html">Security</a>
<a href="https://privacy.apache.org/policies/privacy-policy-public.html">Privacy</a>
<a href="https://www.apache.org/events/current-event">Events</a>
<a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a>
<a href="https://www.apache.org/foundation/thanks.html">Thanks</a>
</div>
</div>
<div class="dropdown">
<span class="dropdown-header">1.9.1
<svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content">
<a href="/">master</a>
<a class="dropdown-option-active" href="/versions/1.9.1/">1.9.1</a>
<a href="/versions/1.8.0/">1.8.0</a>
<a href="/versions/1.7.0/">1.7.0</a>
<a href="/versions/1.6.0/">1.6.0</a>
<a href="/versions/1.5.0/">1.5.0</a>
<a href="/versions/1.4.1/">1.4.1</a>
<a href="/versions/1.3.1/">1.3.1</a>
<a href="/versions/1.2.1/">1.2.1</a>
<a href="/versions/1.1.0/">1.1.0</a>
<a href="/versions/1.0.0/">1.0.0</a>
<a href="/versions/0.12.1/">0.12.1</a>
<a href="/versions/0.11.0/">0.11.0</a>
</div>
</div>
</div>
</nav>
</div>
</header>
<main class="page-content" aria-label="Content">
<script>
</script>
<article class="post">
<header class="post-header wrapper">
<h1 class="post-title">Model API <em>Deprecated</em></h1>
<h3></h3></header>
<div class="post-content">
<div class="wrapper">
<div class="row">
<div class="col-3 docs-side-bar">
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<div class="docs-card docs-side">
<ul>
<div class="docs-action-btn">
<a href="/versions/1.9.1/api/scala.html"> <img src="/versions/1.9.1/assets/img/compass.svg"
class="docs-logo-docs" alt="">Scala Guide <span
class="span-accented"></span></a>
</div>
<div class="docs-action-btn">
<a href="/versions/1.9.1/api/scala/docs/tutorials"> <img
src="/versions/1.9.1/assets/img/video-tutorial.svg" class="docs-logo-docs" alt="">Scala
Tutorials <span class="span-accented"></span></a>
</div>
<div class="docs-action-btn">
<a href="/versions/1.9.1/api/scala/docs/api"> <img src="/versions/1.9.1/assets/img/api.svg"
class="docs-logo-docs" alt="">Scala API Reference
<span class="span-accented"></span></a>
</div>
<!-- Let's show the list of tutorials -->
<br>
<h3>Tutorials</h3>
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/char_lstm">Char-LSTM</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/infer">Infer API</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/io">Data Loading API</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/kvstore">KVStore API</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/mnist">MNIST Example</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/model">Model API <em>Deprecated</em></a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/module">Module API</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/mxnet_scala_on_intellij">Scala on IntelliJ</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/ndarray">NDArray</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/symbol">Symbol API</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<li><a href="/versions/1.9.1/api/scala/docs/tutorials/symbol_in_pictures">Symbol in Pictures</a></li>
<!-- page-category -->
<!-- resource-p -->
<!-- page-category -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page -->
</ul>
</div>
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- resource-p -->
<!-- page -->
</div>
<div class="col-9">
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
<h1 id="mxnet-scala-model-api">MXNet Scala Model API</h1>
<p>The model API provides a simplified way to train neural networks using common best practices.
It&#39;s a thin wrapper, built on top of the <a href="ndarray">ndarray</a> and <a href="symbol">symbolic</a>
modules, that makes neural network training easy.</p>
<p>Topics:</p>
<ul>
<li><a href="#train-the-model">Train the Model</a></li>
<li><a href="#save-the-model">Save the Model</a></li>
<li><a href="#periodic-checkpointing">Periodic Checkpoint</a></li>
<li><a href="#use-multiple-devices">Multiple Devices</a></li>
<li><a href="/versions/1.9.1/api/scala/docs/api/#org.apache.mxnet.Model">Model API Reference</a></li>
</ul>
<h2 id="train-the-model">Train the Model</h2>
<p>To train a model, perform two steps: define the network as a symbol, then pass that symbol to
<code>FeedForward.newBuilder</code>, configure training, and call <code>build()</code> to construct the model and fit it on the training data.
The following example creates a two-layer neural network.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="c1">// configure a two layer neural network
</span> <span class="k">val</span> <span class="nv">data</span> <span class="k">=</span> <span class="nv">Symbol</span><span class="o">.</span><span class="py">Variable</span><span class="o">(</span><span class="s">"data"</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">fc1</span> <span class="k">=</span> <span class="nv">Symbol</span><span class="o">.</span><span class="py">api</span><span class="o">.</span><span class="py">FullyConnected</span><span class="o">(</span><span class="nc">Some</span><span class="o">(</span><span class="n">data</span><span class="o">),</span> <span class="n">num_hidden</span> <span class="k">=</span> <span class="mi">128</span><span class="o">,</span> <span class="n">name</span> <span class="k">=</span> <span class="s">"fc1"</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">act1</span> <span class="k">=</span> <span class="nv">Symbol</span><span class="o">.</span><span class="py">api</span><span class="o">.</span><span class="py">Activation</span><span class="o">(</span><span class="nc">Some</span><span class="o">(</span><span class="n">fc1</span><span class="o">),</span> <span class="s">"relu"</span><span class="o">,</span> <span class="s">"relu1"</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">fc2</span> <span class="k">=</span> <span class="nv">Symbol</span><span class="o">.</span><span class="py">api</span><span class="o">.</span><span class="py">FullyConnected</span><span class="o">(</span><span class="nc">Some</span><span class="o">(</span><span class="n">act1</span><span class="o">),</span> <span class="n">num_hidden</span> <span class="k">=</span> <span class="mi">64</span><span class="o">,</span> <span class="n">name</span> <span class="k">=</span> <span class="s">"fc2"</span><span class="o">)</span>
<span class="k">val</span> <span class="nv">softmax</span> <span class="k">=</span> <span class="nv">Symbol</span><span class="o">.</span><span class="py">api</span><span class="o">.</span><span class="py">SoftmaxOutput</span><span class="o">(</span><span class="nc">Some</span><span class="o">(</span><span class="n">fc2</span><span class="o">),</span> <span class="n">name</span> <span class="k">=</span> <span class="s">"sm"</span><span class="o">)</span>
<span class="c1">// Construct the FeedForward model and fit on the input training data
</span> <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="nv">FeedForward</span><span class="o">.</span><span class="py">newBuilder</span><span class="o">(</span><span class="n">softmax</span><span class="o">)</span>
<span class="o">.</span><span class="py">setContext</span><span class="o">(</span><span class="nv">Context</span><span class="o">.</span><span class="py">cpu</span><span class="o">())</span>
<span class="o">.</span><span class="py">setNumEpoch</span><span class="o">(</span><span class="n">num_epoch</span><span class="o">)</span>
<span class="o">.</span><span class="py">setOptimizer</span><span class="o">(</span><span class="k">new</span> <span class="nc">SGD</span><span class="o">(</span><span class="n">learningRate</span> <span class="k">=</span> <span class="mf">0.01f</span><span class="o">,</span> <span class="n">momentum</span> <span class="k">=</span> <span class="mf">0.9f</span><span class="o">,</span> <span class="n">wd</span> <span class="k">=</span> <span class="mf">0.0001f</span><span class="o">))</span>
<span class="o">.</span><span class="py">setTrainData</span><span class="o">(</span><span class="n">trainDataIter</span><span class="o">)</span>
<span class="o">.</span><span class="py">setEvalData</span><span class="o">(</span><span class="n">valDataIter</span><span class="o">)</span>
<span class="o">.</span><span class="py">build</span><span class="o">()</span>
</code></pre></div>
<p>You can also create a model in the two-step, scikit-learn style: construct a <code>FeedForward</code> object, then call its <code>fit</code> method.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="c1">// create a model using sklearn-style two-step way
</span> <span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">FeedForward</span><span class="o">(</span><span class="n">softmax</span><span class="o">,</span>
<span class="n">numEpoch</span> <span class="k">=</span> <span class="n">numEpochs</span><span class="o">,</span>
<span class="n">argParams</span> <span class="k">=</span> <span class="n">argParams</span><span class="o">,</span>
<span class="n">auxParams</span> <span class="k">=</span> <span class="n">auxParams</span><span class="o">,</span>
<span class="n">beginEpoch</span> <span class="k">=</span> <span class="n">beginEpoch</span><span class="o">,</span>
<span class="n">epochSize</span> <span class="k">=</span> <span class="n">epochSize</span><span class="o">)</span>
<span class="nv">model</span><span class="o">.</span><span class="py">fit</span><span class="o">(</span><span class="n">trainData</span> <span class="k">=</span> <span class="n">train</span><span class="o">)</span>
</code></pre></div>
<p>For more information, see <a href="/versions/1.9.1/api/scala/docs/api/#package">API Reference</a>.</p>
<h2 id="save-the-model">Save the Model</h2>
<p>After training is done, save your work.
Use <code>Model.saveCheckpoint</code> to save a model checkpoint to file, and <code>FeedForward.load</code> to load a model back from a checkpoint.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="c1">// checkpoint the model data into file,
</span> <span class="c1">// save a model to modelPrefix-symbol.json and modelPrefix-0100.params
</span> <span class="k">val</span> <span class="nv">modelPrefix</span><span class="k">:</span> <span class="kt">String</span> <span class="o">=</span> <span class="s">"checkpt"</span>
<span class="k">val</span> <span class="nv">num_epoch</span> <span class="k">=</span> <span class="mi">100</span>
<span class="nv">Model</span><span class="o">.</span><span class="py">saveCheckpoint</span><span class="o">(</span><span class="n">modelPrefix</span><span class="o">,</span> <span class="n">num_epoch</span><span class="o">,</span> <span class="n">symbol</span><span class="o">,</span> <span class="n">argParams</span><span class="o">,</span> <span class="n">auxStates</span><span class="o">)</span>
<span class="c1">// load model back
</span> <span class="k">val</span> <span class="nv">model_loaded</span> <span class="k">=</span> <span class="nv">FeedForward</span><span class="o">.</span><span class="py">load</span><span class="o">(</span><span class="n">modelPrefix</span><span class="o">,</span> <span class="n">num_epoch</span><span class="o">)</span>
</code></pre></div>
<p>The advantage of these save and load functions is that the checkpoint files they produce are language agnostic.
You should be able to save to and load from cloud storage, such as Amazon S3 and HDFS, directly.</p>
<h2 id="periodic-checkpointing">Periodic Checkpointing</h2>
<p>We recommend checkpointing your model after each epoch.
To do this, implement an <code>EpochEndCallback</code> whose <code>invoke</code> method calls <code>Model.saveCheckpoint(&lt;parameters&gt;)</code>, as in the following example.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="c1">// modelPrefix-symbol.json will be saved for symbol.
</span> <span class="c1">// modelPrefix-epoch.params will be saved for parameters.
</span> <span class="c1">// Checkpoint the model into file. Can specify parameters.
</span> <span class="c1">// For more information, check API doc.
</span> <span class="k">val</span> <span class="nv">modelPrefix</span><span class="k">:</span> <span class="kt">String</span> <span class="o">=</span> <span class="s">"checkpt"</span>
<span class="k">val</span> <span class="nv">checkpoint</span><span class="k">:</span> <span class="kt">EpochEndCallback</span> <span class="o">=</span>
<span class="nf">if</span> <span class="o">(</span><span class="n">modelPrefix</span> <span class="o">==</span> <span class="kc">null</span><span class="o">)</span> <span class="kc">null</span>
<span class="k">else</span> <span class="k">new</span> <span class="nc">EpochEndCallback</span> <span class="o">{</span>
<span class="k">override</span> <span class="k">def</span> <span class="nf">invoke</span><span class="o">(</span><span class="n">epoch</span><span class="k">:</span> <span class="kt">Int</span><span class="o">,</span> <span class="n">symbol</span><span class="k">:</span> <span class="kt">Symbol</span><span class="o">,</span>
<span class="n">argParams</span><span class="k">:</span> <span class="kt">Map</span><span class="o">[</span><span class="kt">String</span>, <span class="kt">NDArray</span><span class="o">],</span>
<span class="n">auxStates</span><span class="k">:</span> <span class="kt">Map</span><span class="o">[</span><span class="kt">String</span>, <span class="kt">NDArray</span><span class="o">])</span><span class="k">:</span> <span class="kt">Unit</span> <span class="o">=</span> <span class="o">{</span>
<span class="nv">Model</span><span class="o">.</span><span class="py">saveCheckpoint</span><span class="o">(</span><span class="n">modelPrefix</span><span class="o">,</span> <span class="n">epoch</span> <span class="o">+</span> <span class="mi">1</span><span class="o">,</span> <span class="n">symbol</span><span class="o">,</span> <span class="n">argParams</span><span class="o">,</span> <span class="n">auxStates</span><span class="o">)</span>
<span class="o">}</span>
<span class="o">}</span>
<span class="c1">// Load model checkpoint from file. Returns symbol, argParams, auxParams.
</span> <span class="nf">val</span> <span class="o">(</span><span class="k">_</span><span class="o">,</span> <span class="n">argParams</span><span class="o">,</span> <span class="k">_</span><span class="o">)</span> <span class="k">=</span> <span class="nv">Model</span><span class="o">.</span><span class="py">loadCheckpoint</span><span class="o">(</span><span class="n">modelPrefix</span><span class="o">,</span> <span class="n">num_epoch</span><span class="o">)</span>
</code></pre></div>
<p>You can load the model checkpoint later using <code>Model.loadCheckpoint(modelPrefix, num_epoch)</code>.</p>
<h2 id="use-multiple-devices">Use Multiple Devices</h2>
<p>Set <code>ctx</code> to the list of devices that you want to train on. You can create a list of devices in any way you want.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"> <span class="k">val</span> <span class="nv">devices</span> <span class="k">=</span> <span class="nc">Array</span><span class="o">(</span><span class="nv">Context</span><span class="o">.</span><span class="py">gpu</span><span class="o">(</span><span class="mi">0</span><span class="o">),</span> <span class="nv">Context</span><span class="o">.</span><span class="py">gpu</span><span class="o">(</span><span class="mi">1</span><span class="o">))</span>
<span class="k">val</span> <span class="nv">model</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">FeedForward</span><span class="o">(</span><span class="n">ctx</span> <span class="k">=</span> <span class="n">devices</span><span class="o">,</span>
<span class="n">symbol</span> <span class="k">=</span> <span class="n">network</span><span class="o">,</span>
<span class="n">numEpoch</span> <span class="k">=</span> <span class="n">numEpochs</span><span class="o">,</span>
<span class="n">optimizer</span> <span class="k">=</span> <span class="n">optimizer</span><span class="o">,</span>
<span class="n">epochSize</span> <span class="k">=</span> <span class="n">epochSize</span><span class="o">,</span>
<span class="o">...)</span>
</code></pre></div>
<p>Training occurs in parallel on the GPUs that you specify.</p>
<h2 id="next-steps">Next Steps</h2>
<ul>
<li>See <a href="symbol">Symbolic API</a> for operations on symbols that assemble neural networks from layers.</li>
<li>See <a href="io">IO Data Loading API</a> for parsing and loading data.</li>
<li>See <a href="ndarray">NDArray API</a> for vector/matrix/tensor operations.</li>
<li>See <a href="kvstore">KVStore API</a> for multi-GPU and multi-host distributed training.</li>
</ul>
</div>
</div>
</div>
</div>
</article>
</main><footer class="site-footer h-card">
<div class="wrapper">
<div class="row">
<div class="col-4">
<h4 class="footer-category-title">Resources</h4>
<ul class="contact-list">
<li><a href="/versions/1.9.1/community/contribute#mxnet-dev-communications">Mailing lists</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a href="https://issues.apache.org/jira/projects/MXNET/issues">Jira Tracker</a></li>
<li><a href="https://github.com/apache/mxnet/labels/Roadmap">Github Roadmap</a></li>
<li><a href="https://medium.com/apache-mxnet">Blog</a></li>
<li><a href="https://discuss.mxnet.io">Forum</a></li>
<li><a href="/versions/1.9.1/community/contribute">Contribute</a></li>
</ul>
</div>
<div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/mxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#github"></use></svg> <span class="username">apache/mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="/versions/1.9.1/assets/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul>
</div>
<div class="col-4 footer-text">
<p>A flexible and efficient library for deep learning.</p>
</div>
</div>
</div>
</footer>
<footer class="site-footer2">
<div class="wrapper">
<div class="row">
<div class="col-3">
<img src="/versions/1.9.1/assets/img/asf_logo.svg" class="footer-logo col-2" alt="The Apache Software Foundation">
</div>
<div class="footer-bottom-warning col-9">
<p>Copyright © 2017-2022, The Apache Software Foundation. Licensed under the Apache License, Version 2.0. Apache MXNet, MXNet, Apache, the Apache
feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the
Apache Software Foundation.</p>
</div>
</div>
</div>
</footer>
</body>
</html>