blob: a65479c601811164a286011ee1805cd89f5629e6 [file] [log] [blame]
<!DOCTYPE html><html lang="en"><head><meta charSet="utf-8"/><meta http-equiv="X-UA-Compatible" content="IE=edge"/><title>Distributed Training · Apache SINGA</title><meta name="viewport" content="width=device-width"/><meta name="generator" content="Docusaurus"/><meta name="description" content="&lt;!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --&gt;"/><meta name="docsearch:version" content="3.0.0.rc1"/><meta name="docsearch:language" content="en"/><meta property="og:title" content="Distributed Training · Apache SINGA"/><meta property="og:type" content="website"/><meta property="og:url" content="https://feynmandna.github.io/"/><meta property="og:description" content="&lt;!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --&gt;"/><meta property="og:image" content="https://feynmandna.github.io/img/singa_twitter_banner.jpeg"/><meta name="twitter:card" content="summary"/><meta name="twitter:image" content="https://feynmandna.github.io/img/singa_twitter_banner.jpeg"/><link rel="shortcut icon" href="/img/favicon.ico"/><link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/atom-one-dark.min.css"/><link rel="alternate" type="application/atom+xml" href="https://feynmandna.github.io/blog/atom.xml" title="Apache SINGA Blog ATOM Feed"/><link rel="alternate" type="application/rss+xml" href="https://feynmandna.github.io/blog/feed.xml" title="Apache SINGA Blog RSS Feed"/><script type="text/javascript" src="https://buttons.github.io/buttons.js"></script><script src="https://unpkg.com/vanilla-back-to-top@7.1.14/dist/vanilla-back-to-top.min.js"></script><script>
// When the DOM has been parsed, install the floating "back to top" button
// provided by the vanilla-back-to-top script loaded just before this one.
document.addEventListener('DOMContentLoaded', function () {
  var backToTopOptions = { zIndex: 100 };
  addBackToTop(backToTopOptions);
});
</script><script src="/js/scrollSpy.js"></script><link rel="stylesheet" href="/css/main.css"/><script src="/js/codetabs.js"></script></head><body class="sideNavVisible separateOnPageNav"><div class="fixedHeaderContainer"><div class="headerWrapper wrapper"><header><a href="/"><img class="logo" src="/img/singa.png" alt="Apache SINGA"/></a><a href="/versions"><h3>3.0.0.rc1</h3></a><div class="navigationWrapper navigationSlider"><nav class="slidingNav"><ul class="nav-site nav-site-internal"><li class="siteNavGroupActive"><a href="/docs/installation" target="_self">Docs</a></li><li class=""><a href="/docs/source-repository" target="_self">Community</a></li><li class=""><a href="/blog/" target="_self">News</a></li><li class=""><a href="https://apache-singa.readthedocs.io/en/latest/" target="_self">API</a></li><li class=""><a target="_self"></a></li><li class=""><a href="https://github.com/apache/singa-doc" target="_self">GitHub</a></li></ul></nav></div></header></div></div><div class="navPusher"><div class="docMainWrapper wrapper"><div class="docsNavContainer" id="docsNav"><nav class="toc"><div class="toggleNav"><section class="navWrapper wrapper"><div class="navBreadcrumb wrapper"><div class="navToggle" id="navToggler"><div class="hamburger-menu"><div class="line1"></div><div class="line2"></div><div class="line3"></div></div></div><h2><i></i><span>Guides</span></h2><div class="tocToggler" id="tocToggler"><i class="icon-toc"></i></div></div><div class="navGroups"><div class="navGroup"><h3 class="navGroupCategoryTitle">Getting Started</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/installation">Installation</a></li><li class="navListItem"><a class="navItem" href="/docs/software-stack">Software Stack</a></li><li class="navListItem"><a class="navItem" href="/docs/examples">Examples</a></li></ul></div><div class="navGroup"><h3 class="navGroupCategoryTitle">Guides</h3><ul class=""><li class="navListItem"><a class="navItem" 
href="/docs/device">Device</a></li><li class="navListItem"><a class="navItem" href="/docs/tensor">Tensor</a></li><li class="navListItem"><a class="navItem" href="/docs/autograd">Autograd</a></li><li class="navListItem"><a class="navItem" href="/docs/graph">Computational Graph</a></li><li class="navListItem navListItemActive"><a class="navItem" href="/docs/dist-train">Distributed Training</a></li></ul></div><div class="navGroup"><h3 class="navGroupCategoryTitle">Development</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/download-singa">Download SINGA</a></li><li class="navListItem"><a class="navItem" href="/docs/build">Build SINGA from Source</a></li><li class="navListItem"><a class="navItem" href="/docs/contribute-code">How to Contribute Code</a></li><li class="navListItem"><a class="navItem" href="/docs/contribute-docs">How to Contribute to Documentation</a></li><li class="navListItem"><a class="navItem" href="/docs/how-to-release">How to Prepare a Release</a></li><li class="navListItem"><a class="navItem" href="/docs/git-workflow">Git Workflow</a></li></ul></div></div></section></div><script>
// Sidebar category collapsing. Each element with class 'collapsible' is a
// category header whose next sibling element holds that category's items.
// This runs immediately (not on DOMContentLoaded), so it relies on being
// placed after the nav markup. NOTE(review): no element in this page's nav
// markup carries the 'collapsible' class, so the loop is currently a no-op.
// The `var`s below are script-level and therefore become globals.
var coll = document.getElementsByClassName('collapsible');
// Only the first category that contains the active nav item gets the
// initial toggle below; after that the search is skipped.
var checkActiveCategory = true;
for (var i = 0; i < coll.length; i++) {
var links = coll[i].nextElementSibling.getElementsByTagName('*');
if (checkActiveCategory){
for (var j = 0; j < links.length; j++) {
if (links[j].classList.contains('navListItemActive')){
// Toggle the item list's 'hide' class and the header arrow's 'rotate'
// class. Presumably the arrow is childNodes[1] of the header; confirm
// against the header markup.
coll[i].nextElementSibling.classList.toggle('hide');
coll[i].childNodes[1].classList.toggle('rotate');
checkActiveCategory = false;
break;
}
}
}
// Clicking a category header toggles its arrow and its item list.
coll[i].addEventListener('click', function() {
var arrow = this.childNodes[1];
arrow.classList.toggle('rotate');
var content = this.nextElementSibling;
content.classList.toggle('hide');
});
}
// Once the DOM is ready, wire up the navigation toggles:
//  - clicking #navToggler toggles 'docsSliderActive' on #docsNav (side nav);
//  - clicking #tocToggler toggles 'tocActive' on <body> (on-page TOC);
//  - clicking a link inside .toc-headings removes 'tocActive' from <body>.
document.addEventListener('DOMContentLoaded', function() {
  createToggler('#navToggler', '#docsNav', 'docsSliderActive');
  createToggler('#tocToggler', 'body', 'tocActive');

  var headings = document.querySelector('.toc-headings');
  if (headings) {
    headings.addEventListener('click', function(event) {
      // Walk up from the clicked node toward the .toc-headings container;
      // if an <a> is found on the way, close the on-page TOC.
      var el = event.target;
      while (el !== headings) {
        if (el.tagName === 'A') {
          document.body.classList.remove('tocActive');
          break;
        }
        el = el.parentNode;
      }
    }, false);
  }

  /**
   * Makes clicks on the element matching `togglerSelector` toggle
   * `className` on the element matching `targetSelector`.
   *
   * @param {string} togglerSelector CSS selector of the clickable element.
   * @param {string} targetSelector CSS selector of the element whose class
   *     list is toggled.
   * @param {string} className Class name to toggle on the target.
   *
   * Does nothing when either element is missing from the page. Previously
   * only the toggler was checked, so a missing target surfaced as a
   * TypeError on the first click.
   */
  function createToggler(togglerSelector, targetSelector, className) {
    var toggler = document.querySelector(togglerSelector);
    var target = document.querySelector(targetSelector);
    if (!toggler || !target) {
      return;
    }
    toggler.onclick = function(event) {
      event.preventDefault();
      target.classList.toggle(className);
    };
  }
});
</script></nav></div><div class="container mainContainer docsContainer"><div class="wrapper"><div class="post"><header class="postHeader"><a class="edit-page-link button" href="https://github.com/apache/singa-doc/blob/master/docs/dist-train.md" target="_blank" rel="noreferrer noopener">Edit</a><h1 id="__docusaurus" class="postHeaderTitle">Distributed Training</h1></header><article><div><span><!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -->
<p>SINGA supports data parallel training across multiple GPUs (on a single node or
across different nodes). The following figure illustrates the data parallel
training:</p>
<p><img src="/docs/assets/MPI.png" alt="Illustration of data parallel training across multiple GPU workers"></p>
<p>In distributed training, each process (called a worker) runs a training script
over a single GPU. Each process has an individual communication rank. The
training data is partitioned among the workers and the model is replicated on
every worker. In each iteration, each worker reads a mini-batch of data (e.g.,
256 images) from its partition and runs the BackPropagation algorithm to compute
the gradients of the weights, which are averaged via all-reduce (provided by
<a href="https://developer.nvidia.com/nccl">NCCL</a>) for weight update following
stochastic gradient descent algorithms (SGD).</p>
<p>The all-reduce operation by NCCL can be used to reduce and synchronize the
gradients from different GPUs. Let's consider the training with 4 GPUs as shown
below. Once the gradients from the 4 GPUs are calculated, all-reduce will return
the sum of the gradients over the GPUs and make it available on every GPU. Then
the averaged gradients can be easily calculated.</p>
<p><img src="/docs/assets/AllReduce.png" alt="All-reduce of gradients across 4 GPUs"></p>
<h2><a class="anchor" aria-hidden="true" id="usage"></a><a href="#usage" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Usage</h2>
<p>SINGA implements a module called <code>DistOpt</code> (a subclass of <code>Opt</code>) for distributed
training. It wraps a normal SGD optimizer and calls <code>Communicator</code> for gradient
synchronization. The following example illustrates the usage of <code>DistOpt</code> for
training a CNN model over the MNIST dataset. The source code is available
<a href="https://github.com/apache/singa/blob/master/examples/cnn/">here</a>, and there is
a <a href="">Colab notebook</a> for it.</p>
<h3><a class="anchor" aria-hidden="true" id="example-code"></a><a href="#example-code" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Example Code</h3>
<ol>
<li>Define the neural network model:</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">CNN</span>:</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self)</span>:</span>
self.conv1 = autograd.Conv2d(<span class="hljs-number">1</span>, <span class="hljs-number">20</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
self.conv2 = autograd.Conv2d(<span class="hljs-number">20</span>, <span class="hljs-number">50</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
self.linear1 = autograd.Linear(<span class="hljs-number">4</span> * <span class="hljs-number">4</span> * <span class="hljs-number">50</span>, <span class="hljs-number">500</span>)
self.linear2 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">10</span>)
self.pooling1 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
self.pooling2 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
y = self.conv1(x)
y = autograd.relu(y)
y = self.pooling1(y)
y = self.conv2(y)
y = autograd.relu(y)
y = self.pooling2(y)
y = autograd.flatten(y)
y = self.linear1(y)
y = autograd.relu(y)
y = self.linear2(y)
<span class="hljs-keyword">return</span> y
<span class="hljs-comment"># create model</span>
model = CNN()
</code></pre>
<ol start="2">
<li>Create the <code>DistOpt</code> instance:</li>
</ol>
<pre><code class="hljs css language-python">sgd = opt.SGD(lr=<span class="hljs-number">0.005</span>, momentum=<span class="hljs-number">0.9</span>, weight_decay=<span class="hljs-number">1e-5</span>)
sgd = opt.DistOpt(sgd)
dev = device.create_cuda_gpu_on(sgd.local_rank)
</code></pre>
<p>Here are some explanations concerning some variables in the code:</p>
<p>(i) <code>dev</code></p>
<p><code>dev</code> represents the <code>Device</code> instance on which the data is loaded and the CNN model is run.</p>
<p>(ii) <code>local_rank</code></p>
<p>Local rank represents the GPU number the current process is using in the same
node. For example, if you are using a node with 2 GPUs, <code>local_rank=0</code> means
that this process is using the first GPU, while <code>local_rank=1</code> means using the
second GPU. Using MPI or multiprocessing, you can run the same training script
in every process; the processes differ only in the value of <code>local_rank</code>.</p>
<p>(iii) <code>global_rank</code></p>
<p>The global rank identifies a process among all the processes across all the
nodes you are using. Consider the case where you have 3 nodes and each node
has two GPUs: <code>global_rank=0</code> means the process using the 1st GPU of
the 1st node, <code>global_rank=2</code> means the process using the 1st GPU of the 2nd
node, and <code>global_rank=4</code> means the process using the 1st GPU of the 3rd node.</p>
<ol start="3">
<li>Load and partition the training/validation data:</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">data_partition</span><span class="hljs-params">(dataset_x, dataset_y, global_rank, world_size)</span>:</span>
data_per_rank = dataset_x.shape[<span class="hljs-number">0</span>] // world_size
idx_start = global_rank * data_per_rank
idx_end = (global_rank + <span class="hljs-number">1</span>) * data_per_rank
<span class="hljs-keyword">return</span> dataset_x[idx_start:idx_end], dataset_y[idx_start:idx_end]
train_x, train_y, test_x, test_y = load_dataset()
train_x, train_y = data_partition(train_x, train_y,
sgd.global_rank, sgd.world_size)
test_x, test_y = data_partition(test_x, test_y,
sgd.global_rank, sgd.world_size)
</code></pre>
<p>A partition of the dataset is returned for this <code>dev</code>.</p>
<ol start="4">
<li>Initialize and synchronize the model parameters among all workers:</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">synchronize</span><span class="hljs-params">(tensor, dist_opt)</span>:</span>
dist_opt.all_reduce(tensor.data)
tensor /= dist_opt.world_size
<span class="hljs-comment">#Synchronize the initial parameter</span>
tx = tensor.Tensor((batch_size, <span class="hljs-number">1</span>, IMG_SIZE, IMG_SIZE), dev, tensor.float32)
ty = tensor.Tensor((batch_size, num_classes), dev, tensor.int32)
...
out = model.forward(tx)
loss = autograd.softmax_cross_entropy(out, ty)
<span class="hljs-keyword">for</span> p, g <span class="hljs-keyword">in</span> autograd.backward(loss):
synchronize(p, sgd)
</code></pre>
<p>Here, <code>world_size</code> represents the total number of processes in all the nodes you
are using for distributed training.</p>
<ol start="5">
<li>Run BackPropagation and distributed SGD</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> range(max_epoch):
<span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(num_train_batch):
x = train_x[idx[b * batch_size: (b + <span class="hljs-number">1</span>) * batch_size]]
y = train_y[idx[b * batch_size: (b + <span class="hljs-number">1</span>) * batch_size]]
tx.copy_from_numpy(x)
ty.copy_from_numpy(y)
out = model.forward(tx)
loss = autograd.softmax_cross_entropy(out, ty)
<span class="hljs-comment"># do backpropagation and all-reduce</span>
sgd.backward_and_update(loss)
</code></pre>
<h3><a class="anchor" aria-hidden="true" id="execution-instruction"></a><a href="#execution-instruction" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Execution Instruction</h3>
<p>There are two ways to launch the training: MPI or Python multiprocessing.</p>
<h4><a class="anchor" aria-hidden="true" id="python-multiprocessing"></a><a href="#python-multiprocessing" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Python multiprocessing</h4>
<p>It works on a single node with multiple GPUs, where each GPU is one worker.</p>
<ol>
<li>Put all the above training code in a function:</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_mnist_cnn</span><span class="hljs-params">(nccl_id=None, local_rank=None, world_size=None)</span>:</span>
...
</code></pre>
<ol start="2">
<li>Create <code>mnist_multiprocess.py</code></li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-keyword">if</span> __name__ == <span class="hljs-string">'__main__'</span>:
<span class="hljs-comment"># Generate a NCCL ID to be used for collective communication</span>
nccl_id = singa.NcclIdHolder()
<span class="hljs-comment"># Define the number of GPUs to be used in the training process</span>
world_size = int(sys.argv[<span class="hljs-number">1</span>])
<span class="hljs-comment"># Define and launch the multi-processing</span>
<span class="hljs-keyword">import</span> multiprocessing
process = []
<span class="hljs-keyword">for</span> local_rank <span class="hljs-keyword">in</span> range(<span class="hljs-number">0</span>, world_size):
process.append(multiprocessing.Process(target=train_mnist_cnn,
args=(nccl_id, local_rank, world_size)))
<span class="hljs-keyword">for</span> p <span class="hljs-keyword">in</span> process:
p.start()
</code></pre>
<p>Here are some explanations concerning the variables created above:</p>
<p>(i) <code>nccl_id</code></p>
<p>Note that we need to generate a NCCL ID here to be used for collective
communication, and then pass it to all the processes. The NCCL ID is like a
ticket, where only the processes with this ID can join the all-reduce operation.
(When MPI is used instead, passing the NCCL ID is not necessary, because the ID
is broadcast automatically by MPI in our code.)</p>
<p>(ii) <code>world_size</code></p>
<p><code>world_size</code> is the number of GPUs you would like to use for training.</p>
<p>(iii) <code>local_rank</code></p>
<p><code>local_rank</code> determines the local rank of the process in distributed
training and which GPU the process uses. In the code above, we use a for loop to
launch the training function, where the argument <code>local_rank</code> iterates
from 0 to <code>world_size - 1</code>. In this way, different processes use
different GPUs for training.</p>
<p>The arguments for creating the <code>DistOpt</code> instance should be updated as follows:</p>
<pre><code class="hljs css language-python">sgd = opt.DistOpt(sgd, nccl_id=nccl_id, local_rank=local_rank, world_size=world_size)
</code></pre>
<ol start="3">
<li>Run <code>mnist_multiprocess.py</code></li>
</ol>
<pre><code class="hljs css language-sh">python mnist_multiprocess.py 2
</code></pre>
<p>This results in a speed-up compared to single-GPU training.</p>
<pre><code class="hljs">Starting Epoch <span class="hljs-number">0</span>:
Training loss = <span class="hljs-number">408.909790</span>, training accuracy = <span class="hljs-number">0.880475</span>
Evaluation accuracy = <span class="hljs-number">0.956430</span>
Starting Epoch <span class="hljs-number">1</span>:
Training loss = <span class="hljs-number">102.396790</span>, training accuracy = <span class="hljs-number">0.967415</span>
Evaluation accuracy = <span class="hljs-number">0.977564</span>
Starting Epoch <span class="hljs-number">2</span>:
Training loss = <span class="hljs-number">69.217010</span>, training accuracy = <span class="hljs-number">0.977915</span>
Evaluation accuracy = <span class="hljs-number">0.981370</span>
Starting Epoch <span class="hljs-number">3</span>:
Training loss = <span class="hljs-number">54.248390</span>, training accuracy = <span class="hljs-number">0.982823</span>
Evaluation accuracy = <span class="hljs-number">0.984075</span>
Starting Epoch <span class="hljs-number">4</span>:
Training loss = <span class="hljs-number">45.213406</span>, training accuracy = <span class="hljs-number">0.985560</span>
Evaluation accuracy = <span class="hljs-number">0.985276</span>
Starting Epoch <span class="hljs-number">5</span>:
Training loss = <span class="hljs-number">38.868435</span>, training accuracy = <span class="hljs-number">0.987764</span>
Evaluation accuracy = <span class="hljs-number">0.986278</span>
Starting Epoch <span class="hljs-number">6</span>:
Training loss = <span class="hljs-number">34.078186</span>, training accuracy = <span class="hljs-number">0.989149</span>
Evaluation accuracy = <span class="hljs-number">0.987881</span>
Starting Epoch <span class="hljs-number">7</span>:
Training loss = <span class="hljs-number">30.138697</span>, training accuracy = <span class="hljs-number">0.990451</span>
Evaluation accuracy = <span class="hljs-number">0.988181</span>
Starting Epoch <span class="hljs-number">8</span>:
Training loss = <span class="hljs-number">26.854443</span>, training accuracy = <span class="hljs-number">0.991520</span>
Evaluation accuracy = <span class="hljs-number">0.988682</span>
Starting Epoch <span class="hljs-number">9</span>:
Training loss = <span class="hljs-number">24.039650</span>, training accuracy = <span class="hljs-number">0.992405</span>
Evaluation accuracy = <span class="hljs-number">0.989083</span>
</code></pre>
<h4><a class="anchor" aria-hidden="true" id="mpi"></a><a href="#mpi" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>MPI</h4>
<p>It works for both a single node and multiple nodes, as long as there are
multiple GPUs.</p>
<ol>
<li>Create <code>mnist_dist.py</code></li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-keyword">if</span> __name__ == <span class="hljs-string">'__main__'</span>:
train_mnist_cnn()
</code></pre>
<ol start="2">
<li>Generate a hostfile for MPI, e.g. the hostfile below uses 2 processes (i.e.,
2 GPUs) on a single node</li>
</ol>
<pre><code class="hljs css language-txt">localhost:<span class="hljs-number">2</span>
</code></pre>
<ol start="3">
<li>Launch the training via <code>mpiexec</code></li>
</ol>
<pre><code class="hljs css language-sh">mpiexec --hostfile host_file python mnist_dist.py
</code></pre>
<p>This could result in a speed-up compared to single-GPU training.</p>
<pre><code class="hljs">Starting Epoch <span class="hljs-number">0</span>:
Training loss = <span class="hljs-number">383.969543</span>, training accuracy = <span class="hljs-number">0.886402</span>
Evaluation accuracy = <span class="hljs-number">0.954327</span>
Starting Epoch <span class="hljs-number">1</span>:
Training loss = <span class="hljs-number">97.531479</span>, training accuracy = <span class="hljs-number">0.969451</span>
Evaluation accuracy = <span class="hljs-number">0.977163</span>
Starting Epoch <span class="hljs-number">2</span>:
Training loss = <span class="hljs-number">67.166870</span>, training accuracy = <span class="hljs-number">0.978516</span>
Evaluation accuracy = <span class="hljs-number">0.980769</span>
Starting Epoch <span class="hljs-number">3</span>:
Training loss = <span class="hljs-number">53.369656</span>, training accuracy = <span class="hljs-number">0.983040</span>
Evaluation accuracy = <span class="hljs-number">0.983974</span>
Starting Epoch <span class="hljs-number">4</span>:
Training loss = <span class="hljs-number">45.100403</span>, training accuracy = <span class="hljs-number">0.985777</span>
Evaluation accuracy = <span class="hljs-number">0.986078</span>
Starting Epoch <span class="hljs-number">5</span>:
Training loss = <span class="hljs-number">39.330826</span>, training accuracy = <span class="hljs-number">0.987447</span>
Evaluation accuracy = <span class="hljs-number">0.987179</span>
Starting Epoch <span class="hljs-number">6</span>:
Training loss = <span class="hljs-number">34.655270</span>, training accuracy = <span class="hljs-number">0.988799</span>
Evaluation accuracy = <span class="hljs-number">0.987780</span>
Starting Epoch <span class="hljs-number">7</span>:
Training loss = <span class="hljs-number">30.749735</span>, training accuracy = <span class="hljs-number">0.989984</span>
Evaluation accuracy = <span class="hljs-number">0.988281</span>
Starting Epoch <span class="hljs-number">8</span>:
Training loss = <span class="hljs-number">27.422146</span>, training accuracy = <span class="hljs-number">0.991319</span>
Evaluation accuracy = <span class="hljs-number">0.988582</span>
Starting Epoch <span class="hljs-number">9</span>:
Training loss = <span class="hljs-number">24.548153</span>, training accuracy = <span class="hljs-number">0.992171</span>
Evaluation accuracy = <span class="hljs-number">0.988682</span>
</code></pre>
<h2><a class="anchor" aria-hidden="true" id="optimizations-for-distributed-training"></a><a href="#optimizations-for-distributed-training" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Optimizations for Distributed Training</h2>
<p>SINGA provides multiple optimization strategies for distributed training to
reduce the communication cost. Refer to the API for <code>DistOpt</code> for the
configuration of each strategy.</p>
<h3><a class="anchor" aria-hidden="true" id="no-optimizations"></a><a href="#no-optimizations" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>No Optimizations</h3>
<pre><code class="hljs css language-python">sgd.backward_and_update(loss)
</code></pre>
<p><code>loss</code> is the output tensor from the loss function, e.g., cross-entropy for
classification tasks.</p>
<h3><a class="anchor" aria-hidden="true" id="half-precision-gradients"></a><a href="#half-precision-gradients" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Half-precision Gradients</h3>
<pre><code class="hljs css language-python">sgd.backward_and_update_half(loss)
</code></pre>
<p>It converts each gradient value to 16-bit representation (i.e., half-precision)
before calling all-reduce.</p>
<h3><a class="anchor" aria-hidden="true" id="partial-synchronization"></a><a href="#partial-synchronization" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Partial Synchronization</h3>
<pre><code class="hljs css language-python">sgd.backward_and_partial_update(loss)
</code></pre>
<p>In each iteration, every rank performs a local SGD update. Then, only a chunk of
the parameters is averaged for synchronization, which saves communication cost.
The chunk size is configured when creating the <code>DistOpt</code> instance.</p>
<h3><a class="anchor" aria-hidden="true" id="gradient-sparsification"></a><a href="#gradient-sparsification" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Gradient Sparsification</h3>
<pre><code class="hljs css language-python">sgd.backward_and_sparse_update(loss)
</code></pre>
<p>It applies sparsification schemes to select a subset of gradients for
all-reduce. There are two schemes:</p>
<ul>
<li>The top-K largest elements are selected. <code>spars</code> is the fraction (0–1) of the
total elements to be selected.</li>
</ul>
<pre><code class="hljs css language-python">sgd.backward_and_sparse_update(loss = loss, spars = spars, topK = <span class="hljs-literal">True</span>)
</code></pre>
<ul>
<li>All gradients whose absolute values are larger than the predefined threshold
<code>spars</code> are selected.</li>
</ul>
<pre><code class="hljs css language-python">sgd.backward_and_sparse_update(loss = loss, spars = spars, topK = <span class="hljs-literal">False</span>)
</code></pre>
<p>The hyper-parameters are configured when creating the <code>DistOpt</code> instance.</p>
<h2><a class="anchor" aria-hidden="true" id="implementation"></a><a href="#implementation" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Implementation</h2>
<p>This section is mainly for developers who want to know how the code of the
distributed training module is implemented.</p>
<h3><a class="anchor" aria-hidden="true" id="c-interface-for-nccl-communicator"></a><a href="#c-interface-for-nccl-communicator" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>C interface for NCCL communicator</h3>
<p>Firstly, the communication layer is written in C++ in
<a href="https://github.com/apache/singa/blob/master/src/io/communicator.cc">communicator.cc</a>.
It uses the NCCL library for collective communication.</p>
<p>There are two constructors for the communicator: one for MPI and another for
Python multiprocessing.</p>
<p>(i) Constructor using MPI</p>
<p>The constructor first obtains the global rank and the world size, and then
calculates the local rank. Next, rank 0 generates an NCCL ID and broadcasts it to
every rank. After that, it calls the setup function to initialize the NCCL
communicator, CUDA streams, and buffers.</p>
<p>(ii) Constructor using Python multiprocessing</p>
<p>The constructor obtains the rank, the world size, and the NCCL ID from its
input arguments. After that, it calls the setup function to initialize the NCCL
communicator, CUDA streams, and buffers.</p>
<p>After initialization, the communicator provides the all-reduce functionality to
synchronize the model parameters or gradients. For instance, synch takes an input
tensor and performs all-reduce through the NCCL routine. After calling synch, it
is necessary to call the wait function to wait for the all-reduce operation to
complete.</p>
<h3><a class="anchor" aria-hidden="true" id="python-interface-for-distopt"></a><a href="#python-interface-for-distopt" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Python interface for DistOpt</h3>
<p>Then, the Python interface provides a
<a href="https://github.com/apache/singa/blob/master/python/singa/opt.py">DistOpt</a> class
to wrap an
<a href="https://github.com/apache/singa/blob/master/python/singa/opt.py">optimizer</a>
object to perform distributed training based on MPI or multiprocessing. During
initialization, it creates an NCCL communicator object (from the C interface
described in the subsection above). This communicator object is then used
for every all-reduce operation in DistOpt.</p>
<p>With MPI or multiprocessing, each process has an individual rank, which
indicates which GPU that process is using. The training data is
partitioned so that each process can evaluate a sub-gradient on its own
partition of the training data. Once the sub-gradient has been calculated by each
process, the overall stochastic gradient is obtained by all-reducing the
sub-gradients evaluated by all processes.</p>
</span></div></article></div><div class="docLastUpdate"><em>Last updated on 4/9/2020</em></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/graph"><span class="arrow-prev"></span><span>Computational Graph</span></a><a class="docs-next button" href="/docs/download-singa"><span>Download SINGA</span><span class="arrow-next"></span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#usage">Usage</a><ul class="toc-headings"><li><a href="#example-code">Example Code</a></li><li><a href="#execution-instruction">Execution Instruction</a></li></ul></li><li><a href="#optimizations-for-distributed-training">Optimizations for Distributed Training</a><ul class="toc-headings"><li><a href="#no-optimizations">No Optimizations</a></li><li><a href="#half-precision-gradients">Half-precision Gradients</a></li><li><a href="#partial-synchronization">Partial Synchronization</a></li><li><a href="#gradient-sparsification">Gradient Sparsification</a></li></ul></li><li><a href="#implementation">Implementation</a><ul class="toc-headings"><li><a href="#c-interface-for-nccl-communicator">C interface for NCCL communicator</a></li><li><a href="#python-interface-for-distopt">Python interface for DistOpt</a></li></ul></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a href="/en/#">API Reference (coming soon)</a><a href="/docs/model-zoo-cnn-cifar10">Model Zoo</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/news">SINGA News</a><a href="https://github.com/apache/singa-doc">GitHub</a><div class="social"><a class="github-button" 
href="https://github.com/apache/singa-doc" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2020
The Apache Software Foundation. All rights reserved.
Apache SINGA, Apache, the Apache feather logo, and
the Apache SINGA project logos are trademarks of The
Apache Software Foundation. All other marks mentioned
may be trademarks or registered trademarks of their
respective owners.</section></div></footer></div><script>window.twttr=(function(d,s, id){var js,fjs=d.getElementsByTagName(s)[0],t=window.twttr||{};if(d.getElementById(id))return t;js=d.createElement(s);js.id=id;js.src='https://platform.twitter.com/widgets.js';fjs.parentNode.insertBefore(js, fjs);t._e = [];t.ready = function(f) {t._e.push(f);};return t;}(document, 'script', 'twitter-wjs'));</script></body></html>