<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Train-One-Batch &mdash; incubator-singa 0.3.0 documentation</title>
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
<link rel="top" title="incubator-singa 0.3.0 documentation" href="../index.html"/>
<script src="../_static/js/modernizr.min.js"></script>
</head>
<body class="wy-body-for-nav" role="document">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search">
<a href="../index.html" class="icon icon-home"> incubator-singa
<img src="../_static/singa.png" class="logo" />
</a>
<div class="version">
0.3.0
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../downloads.html">Download SINGA</a></li>
<li class="toctree-l1"><a class="reference internal" href="index.html">Documentation</a></li>
</ul>
<p class="caption"><span class="caption-text">Development</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../develop/schedule.html">Development Schedule</a></li>
<li class="toctree-l1"><a class="reference internal" href="../develop/how-contribute.html">How to Contribute to SINGA</a></li>
<li class="toctree-l1"><a class="reference internal" href="../develop/contribute-code.html">How to Contribute Code</a></li>
<li class="toctree-l1"><a class="reference internal" href="../develop/contribute-docs.html">How to Contribute Documentation</a></li>
</ul>
<p class="caption"><span class="caption-text">Community</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../community/source-repository.html">Source Repository</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/mail-lists.html">Project Mailing Lists</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/issue-tracking.html">Issue Tracking</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/team-list.html">The SINGA Team</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" role="navigation" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../index.html">incubator-singa</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="section" id="train-one-batch">
<span id="train-one-batch"></span><h1>Train-One-Batch<a class="headerlink" href="#train-one-batch" title="Permalink to this headline"></a></h1>
<hr class="docutils" />
<p>For each SGD iteration, every worker calls the <code class="docutils literal"><span class="pre">TrainOneBatch</span></code> function to
compute gradients of parameters associated with its local layers (i.e., the layers
dispatched to it). SINGA implements two algorithms for the
<code class="docutils literal"><span class="pre">TrainOneBatch</span></code> function, and users select the one suited to
their model via the <code class="docutils literal"><span class="pre">alg</span></code> field of the job configuration.</p>
<div class="section" id="basic-user-guide">
<span id="basic-user-guide"></span><h2>Basic user guide<a class="headerlink" href="#basic-user-guide" title="Permalink to this headline"></a></h2>
<div class="section" id="back-propagation">
<span id="back-propagation"></span><h3>Back-propagation<a class="headerlink" href="#back-propagation" title="Permalink to this headline"></a></h3>
<p>The <a class="reference external" href="http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf">BP algorithm</a> is used to
compute gradients of feed-forward models, e.g., <a class="reference external" href="cnn.html">CNN</a>
and <a class="reference external" href="mlp.html">MLP</a>, as well as <a class="reference external" href="rnn.html">RNN</a> models in SINGA.</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="c1"># in job.conf</span>
<span class="n">alg</span><span class="p">:</span> <span class="n">kBP</span>
</pre></div>
</div>
<p>To use the BP algorithm for the <code class="docutils literal"><span class="pre">TrainOneBatch</span></code> function, users simply
set the <code class="docutils literal"><span class="pre">alg</span></code> field to <code class="docutils literal"><span class="pre">kBP</span></code>. If a neural net contains user-defined
layers, these layers must be implemented to be consistent with SINGA&#8217;s
implementation of the BP algorithm (see the advanced user guide below).</p>
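<p>As an illustration of what &#8220;consistent with BP&#8221; means, the following is a minimal C++ sketch. The <code class="docutils literal"><span class="pre">Layer</span></code> interface, the <code class="docutils literal"><span class="pre">ScaleLayer</span></code> class and the <code class="docutils literal"><span class="pre">GradientCheck</span></code> helper are hypothetical and much simpler than SINGA&#8217;s own layer classes. The idea it shows: <code class="docutils literal"><span class="pre">ComputeFeature</span></code> computes the output features, and <code class="docutils literal"><span class="pre">ComputeGradient</span></code> must fill in both the parameter gradient (what <code class="docutils literal"><span class="pre">Update</span></code> would send to servers) and the gradient with respect to the layer&#8217;s input (what the preceding layer receives). Comparing the analytic gradient with a central finite difference is a quick sanity check for a hand-written <code class="docutils literal"><span class="pre">ComputeGradient</span></code>.</p>

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical, simplified layer interface (not SINGA's actual Layer API).
struct Layer {
  virtual ~Layer() = default;
  // Forward: compute the output features from the input.
  virtual void ComputeFeature(const std::vector<double>& in) = 0;
  // Backward: given dL/d(output), fill parameter gradients and dL/d(input).
  virtual void ComputeGradient(const std::vector<double>& grad_out) = 0;
};

// Hypothetical element-wise scaling layer: y_i = w * x_i, one scalar parameter w.
struct ScaleLayer : Layer {
  double w = 1.0;
  double grad_w = 0.0;         // dL/dw, what Update() would send to servers
  std::vector<double> input_;  // cached input, needed by the backward pass
  std::vector<double> data_;   // output features
  std::vector<double> grad_;   // dL/d(input), what is passed to the src layer

  void ComputeFeature(const std::vector<double>& in) override {
    input_ = in;
    data_.resize(in.size());
    for (std::size_t i = 0; i < in.size(); ++i) data_[i] = w * in[i];
  }

  void ComputeGradient(const std::vector<double>& grad_out) override {
    assert(grad_out.size() == input_.size());
    grad_w = 0.0;
    grad_.resize(input_.size());
    for (std::size_t i = 0; i < input_.size(); ++i) {
      grad_w += grad_out[i] * input_[i];  // dy_i/dw = x_i
      grad_[i] = grad_out[i] * w;         // dy_i/dx_i = w
    }
  }
};

// Loss L = 0.5 * sum(y_i^2), so dL/dy_i = y_i.
double Loss(const std::vector<double>& y) {
  double s = 0.0;
  for (double v : y) s += 0.5 * v * v;
  return s;
}

// Compare the analytic dL/dw with a central finite difference.
bool GradientCheck(ScaleLayer& layer, const std::vector<double>& x, double eps = 1e-6) {
  layer.ComputeFeature(x);
  layer.ComputeGradient(layer.data_);  // dL/dy = y for this loss
  const double analytic = layer.grad_w;

  const double w0 = layer.w;
  layer.w = w0 + eps;
  layer.ComputeFeature(x);
  const double loss_plus = Loss(layer.data_);
  layer.w = w0 - eps;
  layer.ComputeFeature(x);
  const double loss_minus = Loss(layer.data_);
  layer.w = w0;  // restore the parameter
  const double numeric = (loss_plus - loss_minus) / (2 * eps);
  return std::fabs(analytic - numeric) < 1e-4;
}
```

<p>For example, a <code class="docutils literal"><span class="pre">ScaleLayer</span></code> with w = 2 and input (1, -3, 0.5) that receives an all-ones output gradient yields dL/dw = -1.5 and passes the gradient (2, 2, 2) back to its input.</p>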
</div>
<div class="section" id="contrastive-divergence">
<span id="contrastive-divergence"></span><h3>Contrastive Divergence<a class="headerlink" href="#contrastive-divergence" title="Permalink to this headline"></a></h3>
<p>The <a class="reference external" href="http://www.cs.toronto.edu/~fritz/absps/nccd.pdf">CD algorithm</a> is used to
compute gradients of energy-based models such as RBM.</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="c1"># job.conf</span>
<span class="n">alg</span><span class="p">:</span> <span class="n">kCD</span>
<span class="n">cd_conf</span> <span class="p">{</span>
<span class="n">cd_k</span><span class="p">:</span> <span class="mi">2</span>
<span class="p">}</span>
</pre></div>
</div>
<p>To use the CD algorithm for the <code class="docutils literal"><span class="pre">TrainOneBatch</span></code> function, users set
the <code class="docutils literal"><span class="pre">alg</span></code> field to <code class="docutils literal"><span class="pre">kCD</span></code>. Users can also configure the number of Gibbs sampling steps in
the CD algorithm through the <code class="docutils literal"><span class="pre">cd_k</span></code> field, which defaults to 1.</p>
</div>
</div>
<div class="section" id="advanced-user-guide">
<span id="advanced-user-guide"></span><h2>Advanced user guide<a class="headerlink" href="#advanced-user-guide" title="Permalink to this headline"></a></h2>
<div class="section" id="implementation-of-bp">
<span id="implementation-of-bp"></span><h3>Implementation of BP<a class="headerlink" href="#implementation-of-bp" title="Permalink to this headline"></a></h3>
<p>SINGA implements the BP algorithm as in the following pseudo code:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">BPTrainOnebatch</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">net</span><span class="p">)</span> <span class="p">{</span>
<span class="o">//</span> <span class="n">forward</span> <span class="n">propagate</span>
<span class="n">foreach</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">net</span><span class="o">.</span><span class="n">local_layers</span><span class="p">()</span> <span class="p">{</span>
<span class="k">if</span> <span class="n">IsBridgeDstLayer</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span>
<span class="n">recv</span> <span class="n">data</span> <span class="kn">from</span> <span class="nn">the</span> <span class="n">src</span> <span class="n">layer</span> <span class="p">(</span><span class="n">i</span><span class="o">.</span><span class="n">e</span><span class="o">.</span><span class="p">,</span> <span class="n">BridgeSrcLayer</span><span class="p">)</span>
<span class="n">foreach</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">layer</span><span class="o">.</span><span class="n">params</span><span class="p">()</span>
<span class="n">Collect</span><span class="p">(</span><span class="n">param</span><span class="p">)</span> <span class="o">//</span> <span class="n">recv</span> <span class="n">response</span> <span class="kn">from</span> <span class="nn">servers</span> <span class="k">for</span> <span class="n">last</span> <span class="n">update</span>
<span class="n">layer</span><span class="o">.</span><span class="n">ComputeFeature</span><span class="p">(</span><span class="n">kForward</span><span class="p">)</span>
<span class="k">if</span> <span class="n">IsBridgeSrcLayer</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span>
<span class="n">send</span> <span class="n">layer</span><span class="o">.</span><span class="n">data_</span> <span class="n">to</span> <span class="n">dst</span> <span class="n">layer</span>
<span class="p">}</span>
<span class="o">//</span> <span class="n">backward</span> <span class="n">propagate</span>
<span class="n">foreach</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">reverse</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">local_layers</span><span class="p">)</span> <span class="p">{</span>
<span class="k">if</span> <span class="n">IsBridgeSrcLayer</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span>
<span class="n">recv</span> <span class="n">gradient</span> <span class="kn">from</span> <span class="nn">the</span> <span class="n">dst</span> <span class="n">layer</span> <span class="p">(</span><span class="n">i</span><span class="o">.</span><span class="n">e</span><span class="o">.</span><span class="p">,</span> <span class="n">BridgeDstLayer</span><span class="p">)</span>
<span class="n">recv</span> <span class="n">response</span> <span class="kn">from</span> <span class="nn">servers</span> <span class="k">for</span> <span class="n">last</span> <span class="n">update</span>
<span class="n">layer</span><span class="o">.</span><span class="n">ComputeGradient</span><span class="p">()</span>
<span class="n">foreach</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">layer</span><span class="o">.</span><span class="n">params</span><span class="p">()</span>
<span class="n">Update</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">param</span><span class="p">)</span> <span class="o">//</span> <span class="n">send</span> <span class="n">param</span><span class="o">.</span><span class="n">grad_</span> <span class="n">to</span> <span class="n">servers</span>
<span class="k">if</span> <span class="n">IsBridgeDstLayer</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span>
<span class="n">send</span> <span class="n">layer</span><span class="o">.</span><span class="n">grad_</span> <span class="n">to</span> <span class="n">src</span> <span class="n">layer</span>
<span class="p">}</span>
<span class="p">}</span>
</pre></div>
</div>
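<p>The pseudo code above can be sketched in C++ for the simplest case: a single worker whose layers are all local, so there are no bridge layers and no servers. The <code class="docutils literal"><span class="pre">Layer</span></code> struct, the squared-error loss and the local <code class="docutils literal"><span class="pre">Update</span></code> (plain SGD applied in place, standing in for sending <code class="docutils literal"><span class="pre">param.grad_</span></code> to servers) are assumptions made for this illustration, not SINGA&#8217;s API.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal layer: y = w * x (element-wise), one scalar parameter w.
struct Layer {
  double w;
  double grad_w = 0.0;
  std::vector<double> in, data_, grad_;
  explicit Layer(double w0) : w(w0) {}

  void ComputeFeature(const std::vector<double>& x) {
    in = x;
    data_.resize(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) data_[i] = w * x[i];
  }

  void ComputeGradient(const std::vector<double>& grad_out) {
    grad_w = 0.0;
    grad_.resize(in.size());
    for (std::size_t i = 0; i < in.size(); ++i) {
      grad_w += grad_out[i] * in[i];  // parameter gradient
      grad_[i] = grad_out[i] * w;     // gradient for the previous layer
    }
  }
};

// Local stand-in for Update(step, param): plain SGD instead of a server round trip.
void Update(int /*step*/, Layer& layer, double lr) { layer.w -= lr * layer.grad_w; }

// One BP iteration over a chain of local layers with loss 0.5 * sum((y - t)^2).
// Returns the loss computed before this iteration's updates.
double BPTrainOneBatch(int step, std::vector<Layer>& net, const std::vector<double>& x,
                       const std::vector<double>& target, double lr) {
  assert(!net.empty() && x.size() == target.size());
  // forward propagate: each layer consumes the previous layer's output
  const std::vector<double>* input = &x;
  for (Layer& layer : net) {
    layer.ComputeFeature(*input);
    input = &layer.data_;
  }
  // loss and dL/dy at the output of the last layer
  const std::vector<double>& y = net.back().data_;
  std::vector<double> grad(y.size());
  double loss = 0.0;
  for (std::size_t i = 0; i < y.size(); ++i) {
    grad[i] = y[i] - target[i];
    loss += 0.5 * grad[i] * grad[i];
  }
  // backward propagate in reverse order, updating each layer's parameter
  for (auto it = net.rbegin(); it != net.rend(); ++it) {
    it->ComputeGradient(grad);
    Update(step, *it, lr);
    grad = it->grad_;  // dL/d(input of this layer) feeds the layer before it
  }
  return loss;
}
```

<p>Running this repeatedly on a two-layer chain drives the product of the two weights toward the value that fits the target, with gradients flowing from the last layer back to the first exactly as in the backward loop of the pseudo code.</p>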
<p><code class="docutils literal"><span class="pre">BPTrainOneBatch</span></code> forwards features through all local layers (which can be checked by comparing the layer&#8217;s
partition ID with the worker ID) and then propagates gradients backward in the reverse order.
A <code class="docutils literal"><span class="pre">BridgeDstLayer</span></code> blocks until the feature from its source layer arrives, and a
<a class="reference external" href="layer.html#bridgesrclayer--bridgedstlayer">BridgeSrcLayer</a> blocks until the gradient
from its destination layer arrives. Parameter gradients
are sent to servers via the <code class="docutils literal"><span class="pre">Update</span></code> function. Updated parameters are collected via the
<code class="docutils literal"><span class="pre">Collect</span></code> function, which blocks until the parameter has been updated.
<a class="reference external" href="param.html">Param</a> objects carry a version number, which is used to
check whether a <code class="docutils literal"><span class="pre">Param</span></code> object has been updated.</p>
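<p>The blocking behaviour of <code class="docutils literal"><span class="pre">Collect</span></code> and the role of versions can be illustrated with a small, hypothetical <code class="docutils literal"><span class="pre">VersionedParam</span></code> class built on <code class="docutils literal"><span class="pre">std::mutex</span></code> and <code class="docutils literal"><span class="pre">std::condition_variable</span></code>. In this sketch a &#8220;server&#8221; thread applies an SGD update and increments the version, while the worker waits in <code class="docutils literal"><span class="pre">Collect</span></code> until the version it needs is available. SINGA&#8217;s actual <code class="docutils literal"><span class="pre">Param</span></code> class and server messaging are more involved.</p>

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <thread>

// Hypothetical parameter whose version counts the server-side updates applied to it.
class VersionedParam {
 public:
  explicit VersionedParam(double value) : value_(value) {}

  // Server side: apply one SGD update, bump the version and wake up waiters.
  void ApplyUpdate(double grad, double lr) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      value_ -= lr * grad;
      ++version_;
    }
    cv_.notify_all();
  }

  // Worker side: block until at least `min_version` updates have been applied.
  double Collect(int min_version) {
    assert(min_version >= 0);
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [&] { return version_ >= min_version; });
    return value_;
  }

  int version() {
    std::lock_guard<std::mutex> lock(mu_);
    return version_;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  double value_;
  int version_ = 0;
};

// One round: a "server" thread applies a gradient while the worker blocks in
// Collect() until the next version becomes visible.
double WorkerWaitsForUpdate(VersionedParam& param, double grad, double lr) {
  const int next = param.version() + 1;
  std::thread server([&param, grad, lr] { param.ApplyUpdate(grad, lr); });
  const double fresh = param.Collect(next);  // returns only after the update
  server.join();
  return fresh;
}
```

<p>Using a predicate with <code class="docutils literal"><span class="pre">wait</span></code> makes the worker robust to spurious wake-ups and to the case where the update lands before the worker starts waiting.</p>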
<p>Since RNN models are unrolled into feed-forward models, users need to implement
the forward propagation in the recurrent layer&#8217;s <code class="docutils literal"><span class="pre">ComputeFeature</span></code> function
and the backward propagation in the recurrent layer&#8217;s <code class="docutils literal"><span class="pre">ComputeGradient</span></code>
function. As a result, the whole <code class="docutils literal"><span class="pre">TrainOneBatch</span></code> function runs the
<a class="reference external" href="https://en.wikipedia.org/wiki/Backpropagation_through_time">back-propagation through time (BPTT)</a> algorithm.</p>
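<p>To make the BPTT point concrete, the sketch below unrolls a hypothetical scalar linear RNN, h_t = w * h_{t-1} + x_t with h_0 = 0 and loss L = 0.5 * h_T^2, and back-propagates through the unrolled steps in reverse order. Because every time step shares the same weight w, its gradient is the sum of the contributions from all steps. In SINGA this forward and backward logic would live inside the recurrent layer&#8217;s <code class="docutils literal"><span class="pre">ComputeFeature</span></code> and <code class="docutils literal"><span class="pre">ComputeGradient</span></code> functions; free functions are used here only for brevity.</p>

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Forward pass of a hypothetical scalar linear RNN unrolled over x.size() steps:
// h_t = w * h_{t-1} + x_t with h_0 = 0. Returns all states h_0 .. h_T.
std::vector<double> Forward(double w, const std::vector<double>& x) {
  std::vector<double> h(x.size() + 1, 0.0);
  for (std::size_t t = 1; t <= x.size(); ++t) h[t] = w * h[t - 1] + x[t - 1];
  return h;
}

// BPTT for the loss L = 0.5 * h_T^2: walk the unrolled steps in reverse and
// accumulate the gradient of the shared weight w from every time step.
double BPTTGradient(double w, const std::vector<double>& x) {
  assert(!x.empty());
  const std::vector<double> h = Forward(w, x);
  double delta = h.back();  // dL/dh_T
  double grad_w = 0.0;
  for (std::size_t t = x.size(); t >= 1; --t) {
    grad_w += delta * h[t - 1];  // local term: dh_t/dw = h_{t-1}
    delta *= w;                  // dL/dh_{t-1} = dL/dh_t * w
  }
  return grad_w;
}
```

<p>For w = 0.5 and inputs (1, 2, 3), h_3 = 4.25 and dh_3/dw = 2w * x_1 + x_2 = 3, so the BPTT gradient is 4.25 * 3 = 12.75.</p>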
</div>
<div class="section" id="implementation-of-cd">
<span id="implementation-of-cd"></span><h3>Implementation of CD<a class="headerlink" href="#implementation-of-cd" title="Permalink to this headline"></a></h3>
<p>SINGA implements the CD algorithm as in the following pseudo code:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">CDTrainOneBatch</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">net</span><span class="p">)</span> <span class="p">{</span>
<span class="c1"># positive phase</span>
<span class="n">foreach</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">net</span><span class="o">.</span><span class="n">local_layers</span><span class="p">()</span>
<span class="k">if</span> <span class="n">IsBridgeDstLayer</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span>
<span class="n">recv</span> <span class="n">positive</span> <span class="n">phase</span> <span class="n">data</span> <span class="kn">from</span> <span class="nn">the</span> <span class="n">src</span> <span class="n">layer</span> <span class="p">(</span><span class="n">i</span><span class="o">.</span><span class="n">e</span><span class="o">.</span><span class="p">,</span> <span class="n">BridgeSrcLayer</span><span class="p">)</span>
<span class="n">foreach</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">layer</span><span class="o">.</span><span class="n">params</span><span class="p">()</span>
<span class="n">Collect</span><span class="p">(</span><span class="n">param</span><span class="p">)</span> <span class="o">//</span> <span class="n">recv</span> <span class="n">response</span> <span class="kn">from</span> <span class="nn">servers</span> <span class="k">for</span> <span class="n">last</span> <span class="n">update</span>
<span class="n">layer</span><span class="o">.</span><span class="n">ComputeFeature</span><span class="p">(</span><span class="n">kPositive</span><span class="p">)</span>
<span class="k">if</span> <span class="n">IsBridgeSrcLayer</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span>
<span class="n">send</span> <span class="n">positive</span> <span class="n">phase</span> <span class="n">data</span> <span class="n">to</span> <span class="n">dst</span> <span class="n">layer</span>
<span class="c1"># negative phase</span>
<span class="n">foreach</span> <span class="n">gibbs</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">cd_k</span><span class="p">)</span>  <span class="c1"># cd_k Gibbs sampling steps</span>
<span class="n">foreach</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">net</span><span class="o">.</span><span class="n">local_layers</span><span class="p">()</span>
<span class="k">if</span> <span class="n">IsBridgeDstLayer</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span>
<span class="n">recv</span> <span class="n">negative</span> <span class="n">phase</span> <span class="n">data</span> <span class="kn">from</span> <span class="nn">the</span> <span class="n">src</span> <span class="n">layer</span> <span class="p">(</span><span class="n">i</span><span class="o">.</span><span class="n">e</span><span class="o">.</span><span class="p">,</span> <span class="n">BridgeSrcLayer</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">ComputeFeature</span><span class="p">(</span><span class="n">kNegative</span><span class="p">)</span>
<span class="k">if</span> <span class="n">IsBridgeSrcLayer</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span>
<span class="n">send</span> <span class="n">negative</span> <span class="n">phase</span> <span class="n">data</span> <span class="n">to</span> <span class="n">dst</span> <span class="n">layer</span>
<span class="n">foreach</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">net</span><span class="o">.</span><span class="n">local_layers</span><span class="p">()</span>
<span class="n">layer</span><span class="o">.</span><span class="n">ComputeGradient</span><span class="p">()</span>
<span class="n">foreach</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">layer</span><span class="o">.</span><span class="n">params</span><span class="p">()</span>
<span class="n">Update</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">param</span><span class="p">)</span>
<span class="p">}</span>
</pre></div>
</div>
<p>Parameter gradients are computed after both the positive and the negative phases.</p>
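<p>The following sketch computes the CD-k weight gradient for a tiny binary RBM with visible biases b, hidden biases c and weights W. It is a deterministic mean-field simplification: each Gibbs step propagates probabilities instead of drawing binary samples, which keeps the example reproducible. The function names are hypothetical and unrelated to SINGA&#8217;s RBM layer classes. The returned value is the usual CD estimate &lt;v h&gt;<sub>data</sub> - &lt;v h&gt;<sub>model</sub>, with the model statistics taken after k steps of a chain started at the data; it is ascended (or its negation descended) when updating W.</p>

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;  // W[i][j] connects visible unit i and hidden unit j

double Sigmoid(double z) { return 1.0 / (1.0 + std::exp(-z)); }

// p(h_j = 1 | v) = sigmoid(sum_i v_i * W[i][j] + c_j)
Vec HiddenProbs(const Mat& W, const Vec& c, const Vec& v) {
  Vec h(c.size());
  for (std::size_t j = 0; j < c.size(); ++j) {
    double z = c[j];
    for (std::size_t i = 0; i < v.size(); ++i) z += v[i] * W[i][j];
    h[j] = Sigmoid(z);
  }
  return h;
}

// p(v_i = 1 | h) = sigmoid(sum_j W[i][j] * h_j + b_i)
Vec VisibleProbs(const Mat& W, const Vec& b, const Vec& h) {
  Vec v(b.size());
  for (std::size_t i = 0; i < b.size(); ++i) {
    double z = b[i];
    for (std::size_t j = 0; j < h.size(); ++j) z += W[i][j] * h[j];
    v[i] = Sigmoid(z);
  }
  return v;
}

// Mean-field CD-k estimate of the weight gradient: <v h>_data - <v h>_model.
Mat CDkWeightGradient(const Mat& W, const Vec& b, const Vec& c, const Vec& v0, int k) {
  assert(k >= 1);
  const Vec h0 = HiddenProbs(W, c, v0);  // positive phase
  Vec vk = v0;
  Vec hk = h0;
  for (int step = 0; step < k; ++step) {  // negative phase: k Gibbs steps
    vk = VisibleProbs(W, b, hk);
    hk = HiddenProbs(W, c, vk);
  }
  Mat grad(v0.size(), Vec(c.size()));
  for (std::size_t i = 0; i < v0.size(); ++i)
    for (std::size_t j = 0; j < c.size(); ++j)
      grad[i][j] = v0[i] * h0[j] - vk[i] * hk[j];
  return grad;
}
```

<p>With all-zero parameters and v0 = (1, 0), every probability equals 0.5, so the estimate is (0.25, -0.25) for any k: the positive phase pushes the weight of the active visible unit up and the negative phase pulls both weights down.</p>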
</div>
<div class="section" id="implementing-a-new-algorithm">
<span id="implementing-a-new-algorithm"></span><h3>Implementing a new algorithm<a class="headerlink" href="#implementing-a-new-algorithm" title="Permalink to this headline"></a></h3>
<p>SINGA implements BP and CD by creating two subclasses of
the <a class="reference external" href="../api/classsinga_1_1Worker.html">Worker</a> class:
<a class="reference external" href="../api/classsinga_1_1BPWorker.html">BPWorker</a>&#8217;s <code class="docutils literal"><span class="pre">TrainOneBatch</span></code> function implements the BP
algorithm, and <a class="reference external" href="../api/classsinga_1_1CDWorker.html">CDWorker</a>&#8217;s <code class="docutils literal"><span class="pre">TrainOneBatch</span></code> function implements the CD
algorithm. To implement a new algorithm for the <code class="docutils literal"><span class="pre">TrainOneBatch</span></code> function, users
need to create a new subclass of <code class="docutils literal"><span class="pre">Worker</span></code>, e.g.,</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">FooWorker</span> <span class="p">:</span> <span class="n">public</span> <span class="n">Worker</span> <span class="p">{</span>
<span class="n">void</span> <span class="n">TrainOneBatch</span><span class="p">(</span><span class="nb">int</span> <span class="n">step</span><span class="p">,</span> <span class="n">shared_ptr</span><span class="o">&lt;</span><span class="n">NeuralNet</span><span class="o">&gt;</span> <span class="n">net</span><span class="p">,</span> <span class="n">Metric</span><span class="o">*</span> <span class="n">perf</span><span class="p">)</span> <span class="n">override</span><span class="p">;</span>
<span class="n">void</span> <span class="n">TestOneBatch</span><span class="p">(</span><span class="nb">int</span> <span class="n">step</span><span class="p">,</span> <span class="n">Phase</span> <span class="n">phase</span><span class="p">,</span> <span class="n">shared_ptr</span><span class="o">&lt;</span><span class="n">NeuralNet</span><span class="o">&gt;</span> <span class="n">net</span><span class="p">,</span> <span class="n">Metric</span><span class="o">*</span> <span class="n">perf</span><span class="p">)</span> <span class="n">override</span><span class="p">;</span>
<span class="p">};</span>
</pre></div>
</div>
<p><code class="docutils literal"><span class="pre">FooWorker</span></code> must implement the two functions above, which train and test the model on one
mini-batch respectively. The <code class="docutils literal"><span class="pre">perf</span></code> argument collects
training or testing performance, e.g., the objective loss or accuracy, and is
passed to the <code class="docutils literal"><span class="pre">ComputeFeature</span></code> function of each layer.</p>
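<p>A self-contained sketch of the subclassing pattern follows. <code class="docutils literal"><span class="pre">NeuralNet</span></code>, <code class="docutils literal"><span class="pre">Metric</span></code>, <code class="docutils literal"><span class="pre">Phase</span></code> and <code class="docutils literal"><span class="pre">Worker</span></code> here are simplified, hypothetical stand-ins that only mirror the signatures shown above. The body of <code class="docutils literal"><span class="pre">FooWorker::TrainOneBatch</span></code> records a placeholder loss where a real worker would run its algorithm over the local layers, so that the flow of results into <code class="docutils literal"><span class="pre">perf</span></code> is visible.</p>

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-ins for SINGA's types, reduced to the pattern being shown.
struct NeuralNet {};

// Accumulates named values so that averages can be reported per display step.
struct Metric {
  std::map<std::string, double> sum;
  std::map<std::string, int> count;
  void Add(const std::string& name, double value) {
    sum[name] += value;
    ++count[name];
  }
  double Avg(const std::string& name) const { return sum.at(name) / count.at(name); }
};

enum Phase { kTrain, kVal, kTest };

class Worker {
 public:
  virtual ~Worker() = default;
  virtual void TrainOneBatch(int step, std::shared_ptr<NeuralNet> net, Metric* perf) = 0;
  virtual void TestOneBatch(int step, Phase phase, std::shared_ptr<NeuralNet> net,
                            Metric* perf) = 0;
};

class FooWorker : public Worker {
 public:
  void TrainOneBatch(int step, std::shared_ptr<NeuralNet> net, Metric* perf) override {
    assert(net != nullptr && perf != nullptr);
    // Placeholder: a real worker would run its algorithm over net's local layers.
    perf->Add("loss", 1.0 / (step + 1));
  }

  void TestOneBatch(int step, Phase phase, std::shared_ptr<NeuralNet> net,
                    Metric* perf) override {
    assert(net != nullptr && perf != nullptr);
    (void)step;
    (void)phase;
    perf->Add("accuracy", 0.5);  // placeholder value
  }
};
```

<p>Because both functions are virtual in the base class, the training loop can drive any registered worker through a <code class="docutils literal"><span class="pre">Worker*</span></code> without knowing which algorithm it implements.</p>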
<p>Users can also define configuration fields for <code class="docutils literal"><span class="pre">FooWorker</span></code>, e.g.,</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="c1"># in user.proto</span>
<span class="n">message</span> <span class="n">FooWorkerProto</span> <span class="p">{</span>
<span class="n">optional</span> <span class="n">int32</span> <span class="n">b</span> <span class="o">=</span> <span class="mi">1</span><span class="p">;</span>
<span class="p">}</span>
<span class="n">extend</span> <span class="n">JobProto</span> <span class="p">{</span>
<span class="n">optional</span> <span class="n">FooWorkerProto</span> <span class="n">foo_conf</span> <span class="o">=</span> <span class="mi">101</span><span class="p">;</span>
<span class="p">}</span>
<span class="c1"># in job.proto</span>
<span class="n">message</span> <span class="n">JobProto</span> <span class="p">{</span>
<span class="o">...</span>
<span class="n">extensions</span> <span class="mi">101</span> <span class="n">to</span> <span class="n">max</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
<p>This is similar to <a class="reference external" href="layer.html#implementing-a-new-layer-subclass">adding configuration fields for a new layer</a>.</p>
<p>To use <code class="docutils literal"><span class="pre">FooWorker</span></code>, users need to register it in <a class="reference external" href="programming-guide.html">main.cc</a>
and configure the <code class="docutils literal"><span class="pre">alg</span></code> and <code class="docutils literal"><span class="pre">foo_conf</span></code> fields as follows:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="c1"># in main.cc</span>
<span class="n">const</span> <span class="nb">int</span> <span class="n">kFoo</span> <span class="o">=</span> <span class="mi">3</span><span class="p">;</span> <span class="o">//</span> <span class="n">worker</span> <span class="n">ID</span><span class="p">,</span> <span class="n">must</span> <span class="n">be</span> <span class="n">different</span> <span class="n">from</span> <span class="n">those</span> <span class="n">of</span> <span class="n">CDWorker</span> <span class="ow">and</span> <span class="n">BPWorker</span>
<span class="n">driver</span><span class="o">.</span><span class="n">RegisterWorker</span><span class="o">&lt;</span><span class="n">FooWorker</span><span class="o">&gt;</span><span class="p">(</span><span class="n">kFoo</span><span class="p">);</span>
<span class="c1"># in job.conf</span>
<span class="o">...</span>
<span class="n">alg</span><span class="p">:</span> <span class="mi">3</span>
<span class="p">[</span><span class="n">foo_conf</span><span class="p">]</span> <span class="p">{</span>
<span class="n">b</span><span class="p">:</span> <span class="mi">4</span>
<span class="p">}</span>
</pre></div>
</div>
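<p>Conceptually, registration maps an algorithm ID to a factory, and the <code class="docutils literal"><span class="pre">alg</span></code> value in job.conf selects which worker to construct. The hypothetical <code class="docutils literal"><span class="pre">WorkerRegistry</span></code> below sketches this idea and shows why worker IDs must be unique. It is not SINGA&#8217;s driver implementation, and the IDs 1 and 2 used for <code class="docutils literal"><span class="pre">BPWorker</span></code> and <code class="docutils literal"><span class="pre">CDWorker</span></code> in the usage are placeholders, not the actual values of <code class="docutils literal"><span class="pre">kBP</span></code> and <code class="docutils literal"><span class="pre">kCD</span></code>.</p>

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for SINGA's worker classes.
struct Worker {
  virtual ~Worker() = default;
  virtual std::string Name() const = 0;
};
struct BPWorker : Worker { std::string Name() const override { return "BP"; } };
struct CDWorker : Worker { std::string Name() const override { return "CD"; } };
struct FooWorker : Worker { std::string Name() const override { return "Foo"; } };

// A minimal ID -> factory registry illustrating the idea behind RegisterWorker<T>(id).
class WorkerRegistry {
 public:
  template <typename T>
  void RegisterWorker(int id) {
    // Each ID may be registered only once; a clash would make `alg` ambiguous.
    auto factory = [] { return std::unique_ptr<Worker>(new T()); };
    if (!factories_.emplace(id, factory).second)
      throw std::invalid_argument("worker id already registered");
  }

  // Construct the worker selected by the `alg` field of the job configuration.
  std::unique_ptr<Worker> Create(int alg) const {
    auto it = factories_.find(alg);
    if (it == factories_.end()) throw std::out_of_range("unknown worker id");
    return it->second();
  }

 private:
  std::map<int, std::function<std::unique_ptr<Worker>()>> factories_;
};
```

<p>With BPWorker and CDWorker registered under IDs 1 and 2 and <code class="docutils literal"><span class="pre">FooWorker</span></code> under 3, setting <code class="docutils literal"><span class="pre">alg:</span> <span class="pre">3</span></code> selects <code class="docutils literal"><span class="pre">FooWorker</span></code>, while registering another worker under an ID that is already taken is rejected.</p>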
</div>
</div>
</div>
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>
&copy; Copyright 2016 The Apache Software Foundation. All rights reserved. Apache Singa, Apache, the Apache feather logo, and the Apache Singa project logos are trademarks of The Apache Software Foundation. All other marks mentioned may be trademarks or registered trademarks of their respective owners.
</p>
</div>
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT:'../',
VERSION:'0.3.0',
COLLAPSE_INDEX:false,
FILE_SUFFIX:'.html',
HAS_SOURCE: true
};
</script>
<script type="text/javascript" src="../_static/jquery.js"></script>
<script type="text/javascript" src="../_static/underscore.js"></script>
<script type="text/javascript" src="../_static/doctools.js"></script>
<script type="text/javascript" src="../_static/js/theme.js"></script>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.StickyNav.enable();
});
</script>
<div class="rst-versions shift-up" data-toggle="rst-versions" role="note" aria-label="versions">
<img src="../_static/apache.jpg">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> incubator-singa </span>
v: 0.3.0
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
<dl>
<dt>Languages</dt>
<dd><a href="../../en/index.html">English</a></dd>
<dd><a href="../../zh/index.html">中文</a></dd>
<dd><a href="../../jp/index.html">日本語</a></dd>
<dd><a href="../../kr/index.html">한국어</a></dd>
</dl>
</div>
</div>
</body>
</html>