<!-- blob: 45413dbcdc0a74c42553d592511c46ddc81c63e6 [file] [log] [blame] -->
<!DOCTYPE html><html lang="en"><head><meta charset="utf-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><meta name="generator" content="rustdoc"><meta name="description" content="Neural Network module"><meta name="keywords" content="rust, rustlang, rust-lang, nnet"><title>rusty_machine::learning::nnet - Rust</title><link rel="preload" as="font" type="font/woff2" crossorigin href="../../../SourceSerif4-Regular.ttf.woff2"><link rel="preload" as="font" type="font/woff2" crossorigin href="../../../FiraSans-Regular.woff2"><link rel="preload" as="font" type="font/woff2" crossorigin href="../../../FiraSans-Medium.woff2"><link rel="preload" as="font" type="font/woff2" crossorigin href="../../../SourceCodePro-Regular.ttf.woff2"><link rel="preload" as="font" type="font/woff2" crossorigin href="../../../SourceSerif4-Bold.ttf.woff2"><link rel="preload" as="font" type="font/woff2" crossorigin href="../../../SourceCodePro-Semibold.ttf.woff2"><link rel="stylesheet" href="../../../normalize.css"><link rel="stylesheet" href="../../../rustdoc.css" id="mainThemeStyle"><link rel="stylesheet" href="../../../ayu.css" disabled><link rel="stylesheet" href="../../../dark.css" disabled><link rel="stylesheet" href="../../../light.css" id="themeStyle"><script id="default-settings" ></script><script src="../../../storage.js"></script><script defer src="../../../main.js"></script><noscript><link rel="stylesheet" href="../../../noscript.css"></noscript><link rel="alternate icon" type="image/png" href="../../../favicon-16x16.png"><link rel="alternate icon" type="image/png" href="../../../favicon-32x32.png"><link rel="icon" type="image/svg+xml" href="../../../favicon.svg"></head><body class="rustdoc mod"><!--[if lte IE 11]><div class="warning">This old browser is unsupported and will most likely display funky things.</div><![endif]--><nav class="mobile-topbar"><button class="sidebar-menu-toggle">&#9776;</button><a class="sidebar-logo" 
href="../../../rusty_machine/index.html"><div class="logo-container"><img class="rust-logo" src="../../../rust-logo.svg" alt="logo"></div></a><h2></h2></nav><nav class="sidebar"><a class="sidebar-logo" href="../../../rusty_machine/index.html"><div class="logo-container"><img class="rust-logo" src="../../../rust-logo.svg" alt="logo"></div></a><h2 class="location"><a href="#">Module nnet</a></h2><div class="sidebar-elems"><section><ul class="block"><li><a href="#modules">Modules</a></li><li><a href="#structs">Structs</a></li><li><a href="#traits">Traits</a></li></ul></section></div></nav><main><div class="width-limiter"><nav class="sub"><form class="search-form"><div class="search-container"><span></span><input class="search-input" name="search" autocomplete="off" spellcheck="false" placeholder="Click or press ‘S’ to search, ‘?’ for more options…" type="search"><div id="help-button" title="help" tabindex="-1"><a href="../../../help.html">?</a></div><div id="settings-menu" tabindex="-1"><a href="../../../settings.html" title="settings"><img width="22" height="22" alt="Change settings" src="../../../wheel.svg"></a></div></div></form></nav><section id="main-content" class="content"><div class="main-heading"><h1 class="fqn">Module <a href="../../index.html">rusty_machine</a>::<wbr><a href="../index.html">learning</a>::<wbr><a class="mod" href="#">nnet</a><button id="copy-path" onclick="copy_path(this)" title="Copy item path to clipboard"><img src="../../../clipboard.svg" width="19" height="18" alt="Copy item path"></button></h1><span class="out-of-band"><a class="srclink" href="../../../src/rusty_machine/learning/nnet/mod.rs.html#1-588">source</a> · <a id="toggle-all-docs" href="javascript:void(0)" title="collapse all docs">[<span class="inner">&#x2212;</span>]</a></span></div><details class="rustdoc-toggle top-doc" open><summary class="hideme"><span>Expand description</span></summary><div class="docblock"><p>Neural Network module</p>
<p>Contains an implementation of a simple feed-forward neural network.</p>
<h2 id="usage"><a href="#usage">Usage</a></h2>
<div class="example-wrap"><pre class="rust rust-example-rendered"><code><span class="kw">use </span>rusty_machine::learning::nnet::{NeuralNet, BCECriterion};
<span class="kw">use </span>rusty_machine::learning::toolkit::regularization::Regularization;
<span class="kw">use </span>rusty_machine::learning::toolkit::activ_fn::Sigmoid;
<span class="kw">use </span>rusty_machine::learning::optim::grad_desc::StochasticGD;
<span class="kw">use </span>rusty_machine::linalg::Matrix;
<span class="kw">use </span>rusty_machine::learning::SupModel;
<span class="kw">let </span>inputs = Matrix::new(<span class="number">5</span>,<span class="number">3</span>, <span class="macro">vec!</span>[<span class="number">1.</span>,<span class="number">1.</span>,<span class="number">1.</span>,<span class="number">2.</span>,<span class="number">2.</span>,<span class="number">2.</span>,<span class="number">3.</span>,<span class="number">3.</span>,<span class="number">3.</span>,
<span class="number">4.</span>,<span class="number">4.</span>,<span class="number">4.</span>,<span class="number">5.</span>,<span class="number">5.</span>,<span class="number">5.</span>,]);
<span class="kw">let </span>targets = Matrix::new(<span class="number">5</span>,<span class="number">3</span>, <span class="macro">vec!</span>[<span class="number">1.</span>,<span class="number">0.</span>,<span class="number">0.</span>,<span class="number">0.</span>,<span class="number">1.</span>,<span class="number">0.</span>,<span class="number">0.</span>,<span class="number">0.</span>,<span class="number">1.</span>,
<span class="number">0.</span>,<span class="number">0.</span>,<span class="number">1.</span>,<span class="number">0.</span>,<span class="number">0.</span>,<span class="number">1.</span>]);
<span class="comment">// Set the layer sizes - from input to output
</span><span class="kw">let </span>layers = <span class="kw-2">&amp;</span>[<span class="number">3</span>,<span class="number">5</span>,<span class="number">11</span>,<span class="number">7</span>,<span class="number">3</span>];
<span class="comment">// Choose the BCE criterion with L2 regularization (`lambda=0.1`).
</span><span class="kw">let </span>criterion = BCECriterion::new(Regularization::L2(<span class="number">0.1</span>));
<span class="comment">// We will create a multilayer perceptron and just use the default stochastic gradient descent.
</span><span class="kw">let </span><span class="kw-2">mut </span>model = NeuralNet::mlp(layers, criterion, StochasticGD::default(), Sigmoid);
<span class="comment">// Train the model!
</span>model.train(<span class="kw-2">&amp;</span>inputs, <span class="kw-2">&amp;</span>targets).unwrap();
<span class="kw">let </span>test_inputs = Matrix::new(<span class="number">2</span>,<span class="number">3</span>, <span class="macro">vec!</span>[<span class="number">1.5</span>,<span class="number">1.5</span>,<span class="number">1.5</span>,<span class="number">5.1</span>,<span class="number">5.1</span>,<span class="number">5.1</span>]);
<span class="comment">// And predict new output from the test inputs
</span><span class="kw">let </span>outputs = model.predict(<span class="kw-2">&amp;</span>test_inputs).unwrap();</code></pre></div>
<p>Neural networks are specified via a criterion, similar to
<a href="https://github.com/torch/nn/blob/master/doc/criterion.md">Torch</a>.
A criterion specifies the cost function and any regularization.</p>
<p>You can define your own criterion by implementing the <code>Criterion</code>
trait with a concrete <code>CostFunc</code>.</p>
</div></details><h2 id="modules" class="small-section-header"><a href="#modules">Modules</a></h2><div class="item-table"><div class="item-row"><div class="item-left module-item"><a class="mod" href="net_layer/index.html" title="rusty_machine::learning::nnet::net_layer mod">net_layer</a></div><div class="item-right docblock-short">Neural Network Layers</div></div></div><h2 id="structs" class="small-section-header"><a href="#structs">Structs</a></h2><div class="item-table"><div class="item-row"><div class="item-left module-item"><a class="struct" href="struct.BCECriterion.html" title="rusty_machine::learning::nnet::BCECriterion struct">BCECriterion</a></div><div class="item-right docblock-short">The binary cross entropy criterion.</div></div><div class="item-row"><div class="item-left module-item"><a class="struct" href="struct.BaseNeuralNet.html" title="rusty_machine::learning::nnet::BaseNeuralNet struct">BaseNeuralNet</a></div><div class="item-right docblock-short">Base Neural Network struct</div></div><div class="item-row"><div class="item-left module-item"><a class="struct" href="struct.MSECriterion.html" title="rusty_machine::learning::nnet::MSECriterion struct">MSECriterion</a></div><div class="item-right docblock-short">The mean squared error criterion.</div></div><div class="item-row"><div class="item-left module-item"><a class="struct" href="struct.NeuralNet.html" title="rusty_machine::learning::nnet::NeuralNet struct">NeuralNet</a></div><div class="item-right docblock-short">Neural Network Model</div></div></div><h2 id="traits" class="small-section-header"><a href="#traits">Traits</a></h2><div class="item-table"><div class="item-row"><div class="item-left module-item"><a class="trait" href="trait.Criterion.html" title="rusty_machine::learning::nnet::Criterion trait">Criterion</a></div><div class="item-right docblock-short">Criterion for Neural Networks</div></div></div></section></div></main><div id="rustdoc-vars" data-root-path="../../../" 
data-current-crate="rusty_machine" data-themes="ayu,dark,light" data-resource-suffix="" data-rustdoc-version="1.66.0-nightly (5c8bff74b 2022-10-21)" ></div></body></html>