blob: 3a78135269d7e2b6bf4cb0e626807046f34d2463 [file] [log] [blame]
<!DOCTYPE html><html lang="en"><head><meta charSet="utf-8"/><meta http-equiv="X-UA-Compatible" content="IE=edge"/><title>Model · Apache SINGA</title><meta name="viewport" content="width=device-width"/><meta name="generator" content="Docusaurus"/><meta name="description" content="&lt;!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --&gt;"/><meta name="docsearch:version" content="3.2.0.rc1"/><meta name="docsearch:language" content="en"/><meta property="og:title" content="Model · Apache SINGA"/><meta property="og:type" content="website"/><meta property="og:url" content="https://singa.apache.org/"/><meta property="og:description" content="&lt;!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --&gt;"/><meta property="og:image" content="https://singa.apache.org/img/singa_twitter_banner.jpeg"/><meta name="twitter:card" content="summary"/><meta name="twitter:image" content="https://singa.apache.org/img/singa_twitter_banner.jpeg"/><link rel="shortcut icon" href="/img/favicon.ico"/><link rel="stylesheet" href="https://cdn.jsdelivr.net/docsearch.js/1/docsearch.min.css"/><link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/atom-one-dark.min.css"/><link rel="alternate" type="application/atom+xml" href="https://singa.apache.org/blog/atom.xml" title="Apache SINGA Blog ATOM Feed"/><link rel="alternate" type="application/rss+xml" href="https://singa.apache.org/blog/feed.xml" title="Apache SINGA Blog RSS Feed"/><link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Source+Sans+Pro:400,400i,700"/><link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Baloo+Paaji+2&amp;family=Source+Sans+Pro:wght@200;300&amp;display=swap"/><script type="text/javascript" src="https://buttons.github.io/buttons.js"></script><script src="https://unpkg.com/vanilla-back-to-top@7.1.14/dist/vanilla-back-to-top.min.js"></script><script>
// Install the floating "back to top" button once the document has been parsed.
document.addEventListener('DOMContentLoaded', function() {
  var backToTopOptions = {"zIndex": 100};
  addBackToTop(backToTopOptions);
});
</script><script src="/js/scrollSpy.js"></script><link rel="stylesheet" href="/css/main.css"/><script src="/js/codetabs.js"></script></head><body class="sideNavVisible separateOnPageNav"><div class="fixedHeaderContainer"><div class="headerWrapper wrapper"><header><a href="/"><img class="logo" src="/img/singa.png" alt="Apache SINGA"/></a><a href="/versions"><h3>3.2.0.rc1</h3></a><div class="navigationWrapper navigationSlider"><nav class="slidingNav"><ul class="nav-site nav-site-internal"><li class="siteNavGroupActive"><a href="/docs/3.2.0.rc1/installation" target="_self">Docs</a></li><li class=""><a href="/docs/3.2.0.rc1/source-repository" target="_self">Community</a></li><li class=""><a href="/blog/" target="_self">News</a></li><li class=""><a href="https://apache-singa.readthedocs.io/en/latest/" target="_self">API</a></li><li class="navSearchWrapper reactNavSearchWrapper"><input type="text" id="search_input_react" placeholder="Search" title="Search"/></li><li class=""><a href="https://github.com/apache/singa" target="_self">GitHub</a></li></ul></nav></div></header></div></div><div class="navPusher"><div class="docMainWrapper wrapper"><div class="docsNavContainer" id="docsNav"><nav class="toc"><div class="toggleNav"><section class="navWrapper wrapper"><div class="navBreadcrumb wrapper"><div class="navToggle" id="navToggler"><div class="hamburger-menu"><div class="line1"></div><div class="line2"></div><div class="line3"></div></div></div><h2><i></i><span>Guides</span></h2><div class="tocToggler" id="tocToggler"><i class="icon-toc"></i></div></div><div class="navGroups"><div class="navGroup"><h3 class="navGroupCategoryTitle">Getting Started</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/installation">Installation</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/software-stack">Software Stack</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/examples">Examples</a></li></ul></div><div 
class="navGroup"><h3 class="navGroupCategoryTitle">Guides</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/device">Device</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/tensor">Tensor</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/autograd">Autograd</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/optimizer">Optimizer</a></li><li class="navListItem navListItemActive"><a class="navItem" href="/docs/3.2.0.rc1/graph">Model</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/onnx">ONNX</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/dist-train">Distributed Training</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/time-profiling">Time Profiling</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/half-precision">Half Precision</a></li></ul></div><div class="navGroup"><h3 class="navGroupCategoryTitle">Development</h3><ul class=""><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/downloads">Download SINGA</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/build">Build SINGA from Source</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/contribute-code">How to Contribute Code</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/contribute-docs">How to Contribute to Documentation</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/how-to-release">How to Prepare a Release</a></li><li class="navListItem"><a class="navItem" href="/docs/3.2.0.rc1/git-workflow">Git Workflow</a></li></ul></div></div></section></div><script>
// Collapsible sidebar categories.
// On load, the first category whose list contains the active page has its
// 'hide' / 'rotate' classes toggled; afterwards every category header toggles
// its own arrow and list when clicked.
var coll = document.getElementsByClassName('collapsible');
var checkActiveCategory = true;
Array.prototype.forEach.call(coll, function(header) {
  var descendants = header.nextElementSibling.getElementsByTagName('*');

  // Only search until one active category has been found.
  var holdsActiveItem = checkActiveCategory &&
    Array.prototype.some.call(descendants, function(node) {
      return node.classList.contains('navListItemActive');
    });

  if (holdsActiveItem) {
    header.nextElementSibling.classList.toggle('hide');
    header.childNodes[1].classList.toggle('rotate');
    checkActiveCategory = false;
  }

  header.addEventListener('click', function() {
    // Arrow first, then the list, matching the order of the class updates above.
    this.childNodes[1].classList.toggle('rotate');
    this.nextElementSibling.classList.toggle('hide');
  });
});
// Wire up the mobile navigation controls once the document has been parsed.
document.addEventListener('DOMContentLoaded', function() {
  // Makes a click on the first element matching togglerSelector flip
  // className on the first element matching targetSelector.
  // Does nothing when the toggler element is not present.
  function createToggler(togglerSelector, targetSelector, className) {
    var toggler = document.querySelector(togglerSelector);
    var target = document.querySelector(targetSelector);
    if (toggler) {
      toggler.onclick = function(event) {
        event.preventDefault();
        target.classList.toggle(className);
      };
    }
  }

  createToggler('#navToggler', '#docsNav', 'docsSliderActive');
  createToggler('#tocToggler', 'body', 'tocActive');

  var headings = document.querySelector('.toc-headings');
  if (headings) {
    // Clicking a link inside the on-page table of contents closes it:
    // walk up from the click target until a link or the container is reached.
    headings.addEventListener('click', function(event) {
      for (var node = event.target; node !== headings; node = node.parentNode) {
        if (node.tagName === 'A') {
          document.body.classList.remove('tocActive');
          break;
        }
      }
    }, false);
  }
});
</script></nav></div><div class="container mainContainer docsContainer"><div class="wrapper"><div class="post"><header class="postHeader"><a class="edit-page-link button" href="https://github.com/apache/singa-doc/blob/master/docs-site/docs/graph.md" target="_blank" rel="noreferrer noopener">Edit</a><h1 id="__docusaurus" class="postHeaderTitle">Model</h1></header><article><div><span><!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -->
<p>The forward and backward propagation in a neural network can be represented
using a set of operations such as convolution and pooling. Each operation takes
some input <a href="./tensor">tensors</a> and applies an <a href="./autograd">operator</a> to generate
output tensors. By representing each operator as a node and each tensor as an
edge, all operations form a computational graph. With the computational graph,
speed and memory optimization can be conducted by scheduling the execution of
the operations and memory allocation/release intelligently. In SINGA, users only
need to define the neural network model using the
<a href="https://github.com/apache/singa/blob/master/python/singa/model.py">Model</a> API.
The graph is constructed and optimized at the C++ backend automatically.</p>
<p>In this way, on the one hand, users implement a network using the
<a href="./graph">Model</a> API following the imperative programming style like PyTorch.
Different from PyTorch which recreates the operations in every iteration, SINGA
buffers the operations to create a computational graph implicitly (when this
feature is enabled) after the first iteration. Therefore, on the other hand,
SINGA has a similar computational graph as the one created by libraries using
declarative programming, e.g., TensorFlow. Consequently, it can enjoy the
optimizations done over the graph.</p>
<h2><a class="anchor" aria-hidden="true" id="example"></a><a href="#example" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Example</h2>
<p>The following code illustrates the usage of the <code>Model</code> API.</p>
<ol>
<li>Implement the new model as a subclass of the Model class.</li>
</ol>
<pre><code class="hljs css language-Python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">CNN</span><span class="hljs-params">(model.Model)</span>:</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, num_classes=<span class="hljs-number">10</span>, num_channels=<span class="hljs-number">1</span>)</span>:</span>
super(CNN, self).__init__()
self.conv1 = layer.Conv2d(num_channels, <span class="hljs-number">20</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>, activation=<span class="hljs-string">"RELU"</span>)
self.conv2 = layer.Conv2d(<span class="hljs-number">20</span>, <span class="hljs-number">50</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>, activation=<span class="hljs-string">"RELU"</span>)
self.linear1 = layer.Linear(<span class="hljs-number">500</span>)
self.linear2 = layer.Linear(num_classes)
self.pooling1 = layer.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
self.pooling2 = layer.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
self.relu = layer.ReLU()
self.flatten = layer.Flatten()
self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
y = self.conv1(x)
y = self.pooling1(y)
y = self.conv2(y)
y = self.pooling2(y)
y = self.flatten(y)
y = self.linear1(y)
y = self.relu(y)
y = self.linear2(y)
<span class="hljs-keyword">return</span> y
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_one_batch</span><span class="hljs-params">(self, x, y)</span>:</span>
out = self.forward(x)
loss = self.softmax_cross_entropy(out, y)
self.optimizer(loss)
<span class="hljs-keyword">return</span> out, loss
</code></pre>
<ol start="2">
<li>Create an instance of model, optimizer, device, etc. Compile the model</li>
</ol>
<pre><code class="hljs css language-python">model = CNN()
<span class="hljs-comment"># initialize optimizer and attach it to the model</span>
sgd = opt.SGD(lr=<span class="hljs-number">0.005</span>, momentum=<span class="hljs-number">0.9</span>, weight_decay=<span class="hljs-number">1e-5</span>)
model.set_optimizer(sgd)
<span class="hljs-comment"># initialize device</span>
dev = device.create_cuda_gpu()
<span class="hljs-comment"># input and target placeholders for the model</span>
tx = tensor.Tensor((batch_size, <span class="hljs-number">1</span>, IMG_SIZE, IMG_SIZE), dev, tensor.float32)
ty = tensor.Tensor((batch_size, num_classes), dev, tensor.int32)
<span class="hljs-comment"># compile the model before training</span>
model.compile([tx], is_train=<span class="hljs-literal">True</span>, use_graph=<span class="hljs-literal">True</span>, sequential=<span class="hljs-literal">False</span>)
</code></pre>
<ol start="3">
<li>Train the model iteratively</li>
</ol>
<pre><code class="hljs css language-python"><span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(num_train_batch):
<span class="hljs-comment"># generate the next mini-batch</span>
x, y = ...
<span class="hljs-comment"># Copy the data into input tensors</span>
tx.copy_from_numpy(x)
ty.copy_from_numpy(y)
<span class="hljs-comment"># Training with one batch</span>
out, loss = model(tx, ty)
</code></pre>
<p>A Google Colab notebook of this example is available
<a href="https://colab.research.google.com/drive/1fbGUs1AsoX6bU5F745RwQpohP4bHTktq">here</a>.</p>
<p>More examples:</p>
<ul>
<li><a href="https://github.com/apache/singa/blob/master/examples/mlp/model.py">MLP</a></li>
<li><a href="https://github.com/apache/singa/blob/master/examples/cnn/model/cnn.py">CNN</a></li>
<li><a href="https://github.com/apache/singa/blob/master/examples/cnn/model/resnet.py">ResNet</a></li>
</ul>
<h2><a class="anchor" aria-hidden="true" id="implementation"></a><a href="#implementation" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Implementation</h2>
<h3><a class="anchor" aria-hidden="true" id="graph-construction"></a><a href="#graph-construction" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Graph Construction</h3>
<p>SINGA constructs the computational graph in three steps:</p>
<ol>
<li>buffer the operations</li>
<li>analyze the dependencies between the operations</li>
<li>create the nodes and edges based on the dependencies</li>
</ol>
<p>Take the matrix multiplication operation from the dense layer of a
<a href="https://github.com/apache/singa/blob/master/examples/mlp/model.py">MLP model</a>
as an example. The operation is called in the <code>forward</code> function of the MLP
class</p>
<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MLP</span><span class="hljs-params">(model.Model)</span>:</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, data_size=<span class="hljs-number">10</span>, perceptron_size=<span class="hljs-number">100</span>, num_classes=<span class="hljs-number">10</span>)</span>:</span>
super(MLP, self).__init__()
self.linear1 = layer.Linear(perceptron_size)
...
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, inputs)</span>:</span>
y = self.linear1(inputs)
...
</code></pre>
<p>The <code>Linear</code> layer is composed of the <code>matmul</code> operator. <code>autograd</code> implements
the <code>matmul</code> operator by calling the function <code>Mult</code> exposed from CPP via SWIG.</p>
<pre><code class="hljs css language-python"><span class="hljs-comment"># implementation of matmul()</span>
singa.Mult(inputs, w)
</code></pre>
<p>At the backend, the <code>Mult</code> function is implemented by calling <code>GEMV</code>, a CBLAS
function. Instead of calling <code>GEMV</code> directly, <code>Mult</code> submits <code>GEMV</code> and the
arguments to the device as follows,</p>
<pre><code class="hljs css language-c++"><span class="hljs-comment">// implementation of Mult()</span>
C-&gt;device()-&gt;Exec(
[a, A, b, B, CRef](Context *ctx) <span class="hljs-keyword">mutable</span> {
GEMV&lt;DType, Lang&gt;(a, A, B, b, &amp;CRef, ctx);
},
read_blocks, {C-&gt;block()});
</code></pre>
<p>The <code>Exec</code> function of <code>Device</code> buffers the function and its arguments. In
addition, it also has the information about the blocks (a block is a chunk of
memory for a tensor) to be read and written by this function.</p>
<p>Once <code>Model.forward()</code> has been executed once, all operations are buffered by
<code>Device</code>. Next, the read/write information of all operations is analyzed to
create the computational graph. For example, if a block <code>b</code> is written by one
operation O1 and is later read by another operation O2, we would know O2 depends
on O1 and there is a directed edge from O1 to O2, which represents block <code>b</code> (or
its tensor). After that, a directed acyclic graph is constructed as shown below.
The graph is constructed only once.</p>
<p><img src="/docs/assets/GraphOfMLP.png" alt="The computational graph of MLP"></p>
<p><br/><strong>Figure 1 - The computational graph of the MLP example.</strong></p>
<h3><a class="anchor" aria-hidden="true" id="optimization"></a><a href="#optimization" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Optimization</h3>
<p>Currently, the following optimizations are done based on the computational
graph.</p>
<p><strong>Lazy allocation</strong> When tensor/blocks are created, devices do not allocate
memory for them immediately. Instead, when the block is accessed for the first
time, the memory is allocated.</p>
<p><strong>Automatic recycling</strong> The reference count of each tensor/block is calculated
based on the graph. Before executing the operations, the reference count is the
number of operations that read this block. During the execution, once an
operation is executed, the reference count of every input block is decreased
by 1. If one block's reference count reaches 0, it means that this block will
not be read again in the remaining operations. Therefore, its memory can be
released safely. In addition, SINGA tracks the usage of the block outside of the
graph. If a block is used by Python code (not by autograd operators), it will
not be recycled.</p>
<p><strong>Memory sharing</strong> SINGA uses a memory pool, e.g.,
<a href="https://github.com/NVIDIA/cnmem">CnMem</a>, to manage CUDA memory. With <em>Automatic
recycling</em> and the memory pool, SINGA can share the memory among tensors. Consider
two operations <code>c = a + b</code> and <code>d=2xc</code>. Before executing the second operation,
according to <em>Lazy allocation</em>, the memory of <code>d</code> should be allocated. Suppose <code>a</code>
is not used in the remaining operations. According to <em>Automatic recycling</em>, the block
of <code>a</code> will be released after the first operation. Therefore, SINGA would submit
four operations to the CUDA stream: addition, free <code>a</code>, malloc <code>d</code>, and
multiplication. The memory pool is then able to share the memory released by <code>a</code>
with <code>d</code> instead of asking the GPU to do a real malloc for <code>d</code>.</p>
<p>Other optimization techniques, e.g., from compilers, such as common
sub-expression elimination and parallelizing operations on different CUDA
streams, can also be applied.</p>
<h2><a class="anchor" aria-hidden="true" id="new-operator"></a><a href="#new-operator" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>New Operator</h2>
<p>Each operator defined in the <code>autograd</code> module implements two functions: forward and
backward, which are implemented by calling the operators from the backend. To
add a new operator in <code>autograd</code>, you need to add multiple operators at the
backend.</p>
<p>Take the
<a href="https://github.com/apache/singa/blob/master/python/singa/autograd.py">Conv2d</a>
operator as an example. At the Python side, the forward and backward functions
are implemented by calling the operators from the backend depending on the
device type.</p>
<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">_Conv2d</span><span class="hljs-params">(Operation)</span>:</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x, W, b=None)</span>:</span>
......
<span class="hljs-keyword">if</span> training:
<span class="hljs-keyword">if</span> self.handle.bias_term:
self.inputs = (x, W, b) <span class="hljs-comment"># record x, W, b</span>
<span class="hljs-keyword">else</span>:
self.inputs = (x, W)
<span class="hljs-keyword">if</span> (type(self.handle) != singa.ConvHandle):
<span class="hljs-keyword">return</span> singa.GpuConvForward(x, W, b, self.handle)
<span class="hljs-keyword">else</span>:
<span class="hljs-keyword">return</span> singa.CpuConvForward(x, W, b, self.handle)
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">backward</span><span class="hljs-params">(self, dy)</span>:</span>
<span class="hljs-keyword">if</span> (type(self.handle) != singa.ConvHandle):
dx = singa.GpuConvBackwardx(dy, self.inputs[<span class="hljs-number">1</span>], self.inputs[<span class="hljs-number">0</span>],
self.handle)
dW = singa.GpuConvBackwardW(dy, self.inputs[<span class="hljs-number">0</span>], self.inputs[<span class="hljs-number">1</span>],
self.handle)
db = singa.GpuConvBackwardb(
dy, self.inputs[<span class="hljs-number">2</span>],
self.handle) <span class="hljs-keyword">if</span> self.handle.bias_term <span class="hljs-keyword">else</span> <span class="hljs-literal">None</span>
<span class="hljs-keyword">else</span>:
dx = singa.CpuConvBackwardx(dy, self.inputs[<span class="hljs-number">1</span>], self.inputs[<span class="hljs-number">0</span>],
self.handle)
dW = singa.CpuConvBackwardW(dy, self.inputs[<span class="hljs-number">0</span>], self.inputs[<span class="hljs-number">1</span>],
self.handle)
db = singa.CpuConvBackwardb(
dy, self.inputs[<span class="hljs-number">2</span>],
self.handle) <span class="hljs-keyword">if</span> self.handle.bias_term <span class="hljs-keyword">else</span> <span class="hljs-literal">None</span>
<span class="hljs-keyword">if</span> db:
<span class="hljs-keyword">return</span> dx, dW, db
<span class="hljs-keyword">else</span>:
<span class="hljs-keyword">return</span> dx, dW
</code></pre>
<p>For each operator at the backend, it should be implemented in the following way:</p>
<ul>
<li><p>Suppose the operator is <code>foo()</code>; its real implementation should be wrapped in
another function e.g., <code>_foo()</code>. <code>foo()</code> passes <code>_foo</code> together with the
arguments as a lambda function to <code>Device</code>'s <code>Exec</code> function for buffering.
The blocks to be read and written are also passed to <code>Exec</code>.</p></li>
<li><p>All arguments used in the lambda expression need to be captured according to
the following rules.</p>
<ul>
<li><p><code>capture by value</code>: If the argument variable is a local variable or will be
immediately released (e.g. intermediate tensors), capture it by value. Otherwise, these variables
will be destroyed once <code>foo()</code> exits.</p></li>
<li><p><code>capture by reference</code>: If the variable is recorded on the Python side or is a
persistent variable (e.g. parameter W and ConvHandle in the Conv2d class).</p></li>
<li><p><code>mutable</code>: The lambda expression should have the mutable tag if a variable
captured by value is modified in <code>_foo()</code>.</p></li>
</ul></li>
</ul>
<p>Here is one
<a href="https://github.com/apache/singa/blob/master/src/model/operation/convolution.cc">example</a>
operator implemented at the backend.</p>
<pre><code class="hljs css language-c++"><span class="hljs-function">Tensor <span class="hljs-title">GpuConvBackwardx</span><span class="hljs-params">(<span class="hljs-keyword">const</span> Tensor &amp;dy, <span class="hljs-keyword">const</span> Tensor &amp;W, <span class="hljs-keyword">const</span> Tensor &amp;x,
<span class="hljs-keyword">const</span> CudnnConvHandle &amp;cch)</span> </span>{
CHECK_EQ(dy.device()-&gt;lang(), kCuda);
Tensor dx;
dx.ResetLike(x);
dy.device()-&gt;Exec(
<span class="hljs-comment">/*
* dx is a local variable so it's captured by value
* dy is an intermediate tensor and isn't recorded on the python side
* W is an intermediate tensor but it's recorded on the python side
* cch is a variable and it's recorded on the python side
*/</span>
[dx, dy, &amp;W, &amp;cch](Context *ctx) <span class="hljs-keyword">mutable</span> {
Block *wblock = W.block(), *dyblock = dy.block(), *dxblock = dx.block();
<span class="hljs-keyword">float</span> alpha = <span class="hljs-number">1.f</span>, beta = <span class="hljs-number">0.f</span>;
cudnnConvolutionBackwardData(
ctx-&gt;cudnn_handle, &amp;alpha, cch.filter_desc, wblock-&gt;data(),
cch.y_desc, dyblock-&gt;data(), cch.conv_desc, cch.bp_data_alg,
cch.workspace.block()-&gt;mutable_data(),
cch.workspace_count * <span class="hljs-keyword">sizeof</span>(<span class="hljs-keyword">float</span>), &amp;beta, cch.x_desc,
dxblock-&gt;mutable_data());
},
{dy.block(), W.block()}, {dx.block(), cch.workspace.block()});
<span class="hljs-comment">/* the lambda expression reads the blocks of tensor dy and W
* and writes the blocks of tensor dx and cch.workspace
*/</span>
<span class="hljs-keyword">return</span> dx;
}
</code></pre>
<h2><a class="anchor" aria-hidden="true" id="benchmark"></a><a href="#benchmark" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Benchmark</h2>
<h3><a class="anchor" aria-hidden="true" id="single-node"></a><a href="#single-node" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Single node</h3>
<ul>
<li>Experiment settings
<ul>
<li>Model
<ul>
<li>Using layer: ResNet50 in
<a href="https://github.com/apache/singa/blob/master/examples/cnn/autograd/resnet_cifar10.py">resnet.py</a></li>
<li>Using model: ResNet50 in
<a href="https://github.com/apache/singa/blob/master/examples/cnn/model/resnet.py">resnet.py</a></li>
</ul></li>
<li>GPU: NVIDIA RTX 2080Ti</li>
</ul></li>
<li>Notations
<ul>
<li><code>s</code>: second</li>
<li><code>it</code>: iteration</li>
<li><code>Mem</code>: peak memory usage of a single GPU</li>
<li><code>Throughput</code>: number of images processed per second</li>
<li><code>Time</code>: total time</li>
<li><code>Speed</code>: iterations per second</li>
<li><code>Reduction</code>: the memory usage reduction rate compared with that using layer</li>
<li><code>Speedup</code>: speedup ratio compared with dev branch</li>
</ul></li>
<li>Result
<table style="text-align: center">
<tr>
<th style="text-align: center">Batchsize</th>
<th style="text-align: center">Cases</th>
<th style="text-align: center">Mem(MB)</th>
<th style="text-align: center">Time(s)</th>
<th style="text-align: center">Speed(it/s)</th>
<th style="text-align: center">Throughput</th>
<th style="text-align: center">Reduction</th>
<th style="text-align: center">Speedup</th>
</tr>
<tr>
<td rowspan="4">16</td>
<td nowrap>layer</td>
<td>4975</td>
<td>14.1952</td>
<td>14.0893</td>
<td>225.4285</td>
<td>0.00%</td>
<td>1.0000</td>
</tr>
<tr>
<td nowrap>model:disable graph</td>
<td>4995</td>
<td>14.1264</td>
<td>14.1579</td>
<td>226.5261</td>
<td>-0.40%</td>
<td>1.0049</td>
</tr>
<tr>
<td nowrap>model:enable graph, bfs</td>
<td>3283</td>
<td>13.7438</td>
<td>14.5520</td>
<td>232.8318</td>
<td>34.01%</td>
<td>1.0328</td>
</tr>
<tr>
<td nowrap>model:enable graph, serial</td>
<td>3265</td>
<td>13.7420</td>
<td>14.5540</td>
<td>232.8635</td>
<td>34.37%</td>
<td>1.0330</td>
</tr>
<tr>
<td rowspan="4">32</td>
<td nowrap>layer</td>
<td>10119</td>
<td>13.4587</td>
<td>7.4302</td>
<td>237.7649</td>
<td>0.00%</td>
<td>1.0000</td>
</tr>
<tr>
<td nowrap>model:disable graph</td>
<td>10109</td>
<td>13.2952</td>
<td>7.5315</td>
<td>240.6875</td>
<td>0.10%</td>
<td>1.0123</td>
</tr>
<tr>
<td nowrap>model:enable graph, bfs</td>
<td>6839</td>
<td>13.1059</td>
<td>7.6302</td>
<td>244.1648</td>
<td>32.41%</td>
<td>1.0269</td>
</tr>
<tr>
<td nowrap>model:enable graph, serial</td>
<td>6845</td>
<td>13.0489</td>
<td>7.6635</td>
<td>245.2312</td>
<td>32.35%</td>
<td>1.0314</td>
</tr>
</table>
</li>
</ul>
<h3><a class="anchor" aria-hidden="true" id="multi-processes"></a><a href="#multi-processes" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Multi processes</h3>
<ul>
<li>Experiment settings
<ul>
<li>API
<ul>
<li>using Layer: ResNet50 in
<a href="https://github.com/apache/singa/blob/master/examples/cnn/autograd/resnet_dist.py">resnet_dist.py</a></li>
<li>using Model: ResNet50 in
<a href="https://github.com/apache/singa/blob/master/examples/cnn/model/resnet.py">resnet.py</a></li>
</ul></li>
<li>GPU: NVIDIA RTX 2080Ti * 2</li>
<li>MPI: two MPI processes on one node</li>
</ul></li>
<li>Notations: the same as above</li>
<li>Result
<table style="text-align: center">
<tr>
<th style="text-align: center">Batchsize</th>
<th style="text-align: center">Cases</th>
<th style="text-align: center">Mem(MB)</th>
<th style="text-align: center">Time(s)</th>
<th style="text-align: center">Speed(it/s)</th>
<th style="text-align: center">Throughput</th>
<th style="text-align: center">Reduction</th>
<th style="text-align: center">Speedup</th>
</tr>
<tr>
<td rowspan="4">16</td>
<td nowrap>layer</td>
<td>5439</td>
<td>17.3323</td>
<td>11.5391</td>
<td>369.2522</td>
<td>0.00%</td>
<td>1.0000</td>
</tr>
<tr>
<td nowrap>model:disable graph</td>
<td>5427</td>
<td>17.8232</td>
<td>11.2213</td>
<td>359.0831</td>
<td>0.22%</td>
<td>0.9725</td>
</tr>
<tr>
<td nowrap>model:enable graph, bfs</td>
<td>3389</td>
<td>18.2310</td>
<td>10.9703</td>
<td>351.0504</td>
<td>37.69%</td>
<td>0.9507</td>
</tr>
<tr>
<td nowrap>model:enable graph, serial</td>
<td>3437</td>
<td>17.0389</td>
<td>11.7378</td>
<td>375.6103</td>
<td>36.81%</td>
<td>1.0172</td>
</tr>
<tr>
<td rowspan="4">32</td>
<td nowrap>layer</td>
<td>10547</td>
<td>14.8635</td>
<td>6.7279</td>
<td>430.5858</td>
<td>0.00%</td>
<td>1.0000</td>
</tr>
<tr>
<td nowrap>model:disable graph</td>
<td>10503</td>
<td>14.7746</td>
<td>6.7684</td>
<td>433.1748</td>
<td>0.42%</td>
<td>1.0060</td>
</tr>
<tr>
<td nowrap>model:enable graph, bfs</td>
<td>6935</td>
<td>14.8553</td>
<td>6.7316</td>
<td>430.8231</td>
<td>34.25%</td>
<td>1.0006</td>
</tr>
<tr>
<td nowrap>model:enable graph, serial</td>
<td>7027</td>
<td>14.3271</td>
<td>6.9798</td>
<td>446.7074</td>
<td>33.37%</td>
<td>1.0374</td>
</tr>
</table>
</li>
</ul>
<h3><a class="anchor" aria-hidden="true" id="conclusion"></a><a href="#conclusion" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Conclusion</h3>
<ul>
<li>Training with the computational graph enabled can significantly reduce the
memory footprint.</li>
<li>Currently, there is only a slight improvement in terms of speed. Further
optimizations can be made to improve efficiency.</li>
</ul>
</span></div></article></div><div class="docLastUpdate"><em>Last updated on 10/4/2020</em></div><div class="docs-prevnext"><a class="docs-prev button" href="/docs/3.2.0.rc1/optimizer"><span class="arrow-prev"></span><span>Optimizer</span></a><a class="docs-next button" href="/docs/3.2.0.rc1/onnx"><span>ONNX</span><span class="arrow-next"></span></a></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#example">Example</a></li><li><a href="#implementation">Implementation</a><ul class="toc-headings"><li><a href="#graph-construction">Graph Construction</a></li><li><a href="#optimization">Optimization</a></li></ul></li><li><a href="#new-operator">New Operator</a></li><li><a href="#benchmark">Benchmark</a><ul class="toc-headings"><li><a href="#single-node">Single node</a></li><li><a href="#multi-processes">Multi processes</a></li><li><a href="#conclusion">Conclusion</a></li></ul></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a href="https://apache-singa.readthedocs.io/en/latest/">API Reference</a><a href="/docs/examples">Examples</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/blog">SINGA News</a><a href="https://github.com/apache/singa">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow 
@ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2023
The Apache Software Foundation. All rights reserved.
Apache SINGA, Apache, the Apache feather logo, and
the Apache SINGA project logos are trademarks of The
Apache Software Foundation. All other marks mentioned
may be trademarks or registered trademarks of their
respective owners.</section></div></footer></div><script type="text/javascript" src="https://cdn.jsdelivr.net/docsearch.js/1/docsearch.min.js"></script><script>window.twttr=(function(d,s, id){var js,fjs=d.getElementsByTagName(s)[0],t=window.twttr||{};if(d.getElementById(id))return t;js=d.createElement(s);js.id=id;js.src='https://platform.twitter.com/widgets.js';fjs.parentNode.insertBefore(js, fjs);t._e = [];t.ready = function(f) {t._e.push(f);};return t;}(document, 'script', 'twitter-wjs'));</script><script>
// Keyboard shortcut: releasing the '/' key while the event target is the page
// body moves focus to the search box, if that element is present.
document.addEventListener('keyup', function(event) {
  // 191 is the keyCode for '/' (slash). Events whose target is anything other
  // than document.body are ignored.
  if (event.target === document.body && event.keyCode === 191) {
    const searchInput = document.getElementById('search_input_react');
    if (searchInput) {
      searchInput.focus();
    }
  }
});
</script><script>
// Calls docsearch() with the 'apache_singa' index and binds it to the
// '#search_input_react' input; the returned value is kept in the global `search`.
// NOTE(review): the facet filter below restricts results to version 3.0.0, while
// this page's "docsearch:version" meta tag is 3.2.0.rc1 — confirm whether the
// mismatch is intentional (e.g. only 3.0.0 is indexed) before changing it.
var search = docsearch({
apiKey: '45202133606c0b5fa6d21cddc4725dd8',
indexName: 'apache_singa',
inputSelector: '#search_input_react',
algoliaOptions: {"facetFilters":["language:en","version:3.0.0"]}
});
</script></body></html>