<!DOCTYPE html><html lang="en"><head><meta charSet="utf-8"/><meta http-equiv="X-UA-Compatible" content="IE=edge"/><title>ONNX · Apache SINGA</title><meta name="viewport" content="width=device-width"/><meta name="generator" content="Docusaurus"/><meta name="description" content="&lt;!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --&gt;"/><meta name="docsearch:version" content="2.0.0"/><meta name="docsearch:language" content="en"/><meta property="og:title" content="ONNX · Apache SINGA"/><meta property="og:type" content="website"/><meta property="og:url" content="https://feynmandna.github.io/"/><meta property="og:description" content="&lt;!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the &quot;License&quot;); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an &quot;AS IS&quot; BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --&gt;"/><meta property="og:image" content="https://feynmandna.github.io/img/singa_twitter_banner.jpeg"/><meta name="twitter:card" content="summary"/><meta name="twitter:image" content="https://feynmandna.github.io/img/singa_twitter_banner.jpeg"/><link rel="shortcut icon" href="/img/favicon.ico"/><link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/atom-one-dark.min.css"/><link rel="alternate" type="application/atom+xml" href="https://feynmandna.github.io/blog/atom.xml" title="Apache SINGA Blog ATOM Feed"/><link rel="alternate" type="application/rss+xml" href="https://feynmandna.github.io/blog/feed.xml" title="Apache SINGA Blog RSS Feed"/><script type="text/javascript" src="https://buttons.github.io/buttons.js"></script><script src="https://unpkg.com/vanilla-back-to-top@7.1.14/dist/vanilla-back-to-top.min.js"></script><script>
document.addEventListener('DOMContentLoaded', function() {
addBackToTop(
{"zIndex":100}
)
});
</script><script src="/js/scrollSpy.js"></script><link rel="stylesheet" href="/css/main.css"/><script src="/js/codetabs.js"></script></head><body class="sideNavVisible separateOnPageNav"><div class="fixedHeaderContainer"><div class="headerWrapper wrapper"><header><a href="/"><img class="logo" src="/img/singa.png" alt="Apache SINGA"/></a><a href="/versions"><h3>2.0.0</h3></a><div class="navigationWrapper navigationSlider"><nav class="slidingNav"><ul class="nav-site nav-site-internal"><li class=""><a href="/docs/2.0.0/installation" target="_self">Docs</a></li><li class=""><a href="/docs/2.0.0/source-repository" target="_self">Community</a></li><li class=""><a href="/blog/" target="_self">News</a></li><li class=""><a href="https://apache-singa.readthedocs.io/en/latest/" target="_self">API</a></li><li class=""><a target="_self"></a></li><li class=""><a href="https://github.com/apache/singa-doc" target="_self">GitHub</a></li></ul></nav></div></header></div></div><div class="navPusher"><div class="docMainWrapper wrapper"><div class="container mainContainer docsContainer"><div class="wrapper"><div class="post"><header class="postHeader"><a class="edit-page-link button" href="https://github.com/apache/singa-doc/blob/master/docs/onnx.md" target="_blank" rel="noreferrer noopener">Edit</a><h1 id="__docusaurus" class="postHeaderTitle">ONNX</h1></header><article><div><span><!--- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -->
<p>ONNX is an open format for representing machine learning models; it allows trained models to be transferred between different deep learning frameworks. We have integrated the main functionality of ONNX into SINGA, and several basic operators are already supported. More operators are under development.</p>
<h2><a class="anchor" aria-hidden="true" id="example-onnx-mnist-on-singa"></a><a href="#example-onnx-mnist-on-singa" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Example: ONNX mnist on singa</h2>
<p>We will introduce SINGA's ONNX support using the MNIST example. This section shows how to export, load, run inference with, re-train, and apply transfer learning to the MNIST model.</p>
<h3><a class="anchor" aria-hidden="true" id="load-dataset"></a><a href="#load-dataset" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Load dataset</h3>
<p>First, we import the necessary libraries and define some auxiliary functions for downloading and preprocessing the dataset:</p>
<pre><code class="hljs css language-python"><span class="hljs-keyword">import</span> os
<span class="hljs-keyword">import</span> urllib.request
<span class="hljs-keyword">import</span> gzip
<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-keyword">import</span> codecs
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> device
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> tensor
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> opt
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> autograd
<span class="hljs-keyword">from</span> singa <span class="hljs-keyword">import</span> sonnx
<span class="hljs-keyword">import</span> onnx
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">load_dataset</span><span class="hljs-params">()</span>:</span>
train_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'</span>
train_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'</span>
valid_x_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'</span>
valid_y_url = <span class="hljs-string">'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'</span>
train_x = read_image_file(check_exist_or_download(train_x_url)).astype(
np.float32)
train_y = read_label_file(check_exist_or_download(train_y_url)).astype(
np.float32)
valid_x = read_image_file(check_exist_or_download(valid_x_url)).astype(
np.float32)
valid_y = read_label_file(check_exist_or_download(valid_y_url)).astype(
np.float32)
<span class="hljs-keyword">return</span> train_x, train_y, valid_x, valid_y
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">check_exist_or_download</span><span class="hljs-params">(url)</span>:</span>
download_dir = <span class="hljs-string">'/tmp/'</span>
name = url.rsplit(<span class="hljs-string">'/'</span>, <span class="hljs-number">1</span>)[<span class="hljs-number">-1</span>]
filename = os.path.join(download_dir, name)
<span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> os.path.isfile(filename):
print(<span class="hljs-string">"Downloading %s"</span> % url)
urllib.request.urlretrieve(url, filename)
<span class="hljs-keyword">return</span> filename
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_label_file</span><span class="hljs-params">(path)</span>:</span>
<span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
data = f.read()
<span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2049</span>
length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">8</span>).reshape(
(length))
<span class="hljs-keyword">return</span> parsed
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">get_int</span><span class="hljs-params">(b)</span>:</span>
<span class="hljs-keyword">return</span> int(codecs.encode(b, <span class="hljs-string">'hex'</span>), <span class="hljs-number">16</span>)
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">read_image_file</span><span class="hljs-params">(path)</span>:</span>
<span class="hljs-keyword">with</span> gzip.open(path, <span class="hljs-string">'rb'</span>) <span class="hljs-keyword">as</span> f:
data = f.read()
<span class="hljs-keyword">assert</span> get_int(data[:<span class="hljs-number">4</span>]) == <span class="hljs-number">2051</span>
length = get_int(data[<span class="hljs-number">4</span>:<span class="hljs-number">8</span>])
num_rows = get_int(data[<span class="hljs-number">8</span>:<span class="hljs-number">12</span>])
num_cols = get_int(data[<span class="hljs-number">12</span>:<span class="hljs-number">16</span>])
parsed = np.frombuffer(data, dtype=np.uint8, offset=<span class="hljs-number">16</span>).reshape(
(length, <span class="hljs-number">1</span>, num_rows, num_cols))
<span class="hljs-keyword">return</span> parsed
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">to_categorical</span><span class="hljs-params">(y, num_classes)</span>:</span>
y = np.array(y, dtype=<span class="hljs-string">"int"</span>)
n = y.shape[<span class="hljs-number">0</span>]
categorical = np.zeros((n, num_classes))
categorical[np.arange(n), y] = <span class="hljs-number">1</span>
categorical = categorical.astype(np.float32)
<span class="hljs-keyword">return</span> categorical
</code></pre>
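<p>As a quick illustration (added here for clarity, not part of the original example), the <strong>to_categorical</strong> helper defined above converts integer class labels into the one-hot vectors expected by the training code below:</p>
<pre><code class="hljs css language-python">import numpy as np

# hypothetical usage of the to_categorical helper defined above
labels = np.array([0, 2, 1])         # three samples with classes 0, 2 and 1
one_hot = to_categorical(labels, 3)  # shape (3, 3), dtype float32
# one_hot == [[1., 0., 0.],
#             [0., 0., 1.],
#             [0., 1., 0.]]
</code></pre>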
<h3><a class="anchor" aria-hidden="true" id="mnist-model"></a><a href="#mnist-model" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>MNIST model</h3>
<p>We define a class called <strong>CNN</strong> to construct the MNIST model, which consists of several convolution, pooling, fully connected and ReLU layers. We also define a function to calculate the <strong>accuracy</strong> of our result. Finally, we define a <strong>train</strong> and a <strong>test</strong> function to handle the training and prediction process.</p>
<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">CNN</span>:</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self)</span>:</span>
self.conv1 = autograd.Conv2d(<span class="hljs-number">1</span>, <span class="hljs-number">20</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
self.conv2 = autograd.Conv2d(<span class="hljs-number">20</span>, <span class="hljs-number">50</span>, <span class="hljs-number">5</span>, padding=<span class="hljs-number">0</span>)
self.linear1 = autograd.Linear(<span class="hljs-number">4</span> * <span class="hljs-number">4</span> * <span class="hljs-number">50</span>, <span class="hljs-number">500</span>, bias=<span class="hljs-literal">False</span>)
self.linear2 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)
self.pooling1 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
self.pooling2 = autograd.MaxPool2d(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, padding=<span class="hljs-number">0</span>)
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
y = self.conv1(x)
y = autograd.relu(y)
y = self.pooling1(y)
y = self.conv2(y)
y = autograd.relu(y)
y = self.pooling2(y)
y = autograd.flatten(y)
y = self.linear1(y)
y = autograd.relu(y)
y = self.linear2(y)
<span class="hljs-keyword">return</span> y
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">accuracy</span><span class="hljs-params">(pred, target)</span>:</span>
y = np.argmax(pred, axis=<span class="hljs-number">1</span>)
t = np.argmax(target, axis=<span class="hljs-number">1</span>)
a = y == t
<span class="hljs-keyword">return</span> np.array(a, <span class="hljs-string">"int"</span>).sum() / float(len(t))
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train</span><span class="hljs-params">(model,
x,
y,
epochs=<span class="hljs-number">1</span>,
batch_size=<span class="hljs-number">64</span>,
dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
<span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
l_idx = b * batch_size
r_idx = (b + <span class="hljs-number">1</span>) * batch_size
x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
output_batch = model.forward(x_batch)
<span class="hljs-comment"># onnx_model = sonnx.to_onnx([x_batch], [y])</span>
<span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>
loss = autograd.softmax_cross_entropy(output_batch, target_batch)
accuracy_rate = accuracy(tensor.to_numpy(output_batch),
tensor.to_numpy(target_batch))
sgd = opt.SGD(lr=<span class="hljs-number">0.001</span>)
<span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
sgd.update(p, gp)
sgd.step()
<span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
(accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
print(<span class="hljs-string">"training completed"</span>)
<span class="hljs-keyword">return</span> x_batch, output_batch
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">test</span><span class="hljs-params">(model, x, y, batch_size=<span class="hljs-number">64</span>, dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
result = <span class="hljs-number">0</span>
<span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
l_idx = b * batch_size
r_idx = (b + <span class="hljs-number">1</span>) * batch_size
x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
output_batch = model.forward(x_batch)
result += accuracy(tensor.to_numpy(output_batch),
tensor.to_numpy(target_batch))
print(<span class="hljs-string">"testing acc %6.2f"</span> % (result / batch_number))
</code></pre>
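<p>To see why <strong>linear1</strong> takes 4 * 4 * 50 input features, it helps to trace the spatial size of a 28x28 MNIST image through the layers defined above (a worked check we add here for clarity):</p>
<pre><code class="hljs css language-python"># input image:                1 x 28 x 28
# conv1 (kernel 5, pad 0):   20 x 24 x 24
# pooling1 (2x2, stride 2):  20 x 12 x 12
# conv2 (kernel 5, pad 0):   50 x 8 x 8
# pooling2 (2x2, stride 2):  50 x 4 x 4
# flatten:                   4 * 4 * 50 = 800 features, the input size of linear1
</code></pre>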
<h3><a class="anchor" aria-hidden="true" id="train-mnist-model-and-export-it-to-onnx"></a><a href="#train-mnist-model-and-export-it-to-onnx" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Train mnist model and export it to onnx</h3>
<p>Now, we can train the MNIST model and export it to ONNX by calling the <strong>sonnx.to_onnx</strong> function.</p>
<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">make_onnx</span><span class="hljs-params">(x, y)</span>:</span>
<span class="hljs-keyword">return</span> sonnx.to_onnx([x], [y])
<span class="hljs-comment"># create device</span>
dev = device.create_cuda_gpu()
<span class="hljs-comment">#dev = device.get_default_device()</span>
<span class="hljs-comment"># create model</span>
model = CNN()
<span class="hljs-comment"># load data</span>
train_x, train_y, valid_x, valid_y = load_dataset()
<span class="hljs-comment"># normalization</span>
train_x = train_x / <span class="hljs-number">255</span>
valid_x = valid_x / <span class="hljs-number">255</span>
train_y = to_categorical(train_y, <span class="hljs-number">10</span>)
valid_y = to_categorical(valid_y, <span class="hljs-number">10</span>)
<span class="hljs-comment"># do training</span>
autograd.training = <span class="hljs-literal">True</span>
x, y = train(model, train_x, train_y, dev=dev)
onnx_model = make_onnx(x, y)
<span class="hljs-comment"># print('The model is:\n{}'.format(onnx_model))</span>
<span class="hljs-comment"># Save the ONNX model</span>
model_path = os.path.join(<span class="hljs-string">'/'</span>, <span class="hljs-string">'tmp'</span>, <span class="hljs-string">'mnist.onnx'</span>)
onnx.save(onnx_model, model_path)
print(<span class="hljs-string">'The model is saved.'</span>)
</code></pre>
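<p>As an optional sanity check (not part of the original example), the saved file can be loaded back and validated with ONNX's own checker before it is handed to another framework:</p>
<pre><code class="hljs css language-python">import onnx

# reload the exported file and verify that it is a well-formed ONNX model
loaded_model = onnx.load(model_path)
onnx.checker.check_model(loaded_model)  # raises an exception if the model is invalid
print('mnist.onnx passed the ONNX checker')
</code></pre>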
<h3><a class="anchor" aria-hidden="true" id="inference"></a><a href="#inference" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Inference</h3>
<p>After exporting, we can find a file called <strong>mnist.onnx</strong> in the '/tmp' directory; this model can therefore be imported by other libraries. To import this ONNX model back into SINGA and run inference on the validation dataset, we define a class called <strong>Infer</strong>; its forward function is called by the test function to run inference on the validation dataset. Note that we set <strong>autograd.training</strong> to <strong>False</strong> so that the autograd operators do not compute gradients.</p>
<p>When importing the ONNX model, we first call <strong>onnx.load</strong> to load it. The loaded model is then fed into <strong>sonnx.prepare</strong>, which parses it and initializes a SINGA model (<strong>sg_ir</strong> in the code). The sg_ir holds a SINGA graph, and we can run one step of inference by feeding an input to its run function.</p>
<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Infer</span>:</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir)</span>:</span>
self.sg_ir = sg_ir
<span class="hljs-keyword">for</span> idx, tens <span class="hljs-keyword">in</span> sg_ir.tensor_map.items():
<span class="hljs-comment"># allow the tensors to be updated</span>
tens.requires_grad = <span class="hljs-literal">True</span>
tens.stores_grad= <span class="hljs-literal">True</span>
sg_ir.tensor_map[idx] = tens
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
<span class="hljs-keyword">return</span> sg_ir.run([x])[<span class="hljs-number">0</span>] <span class="hljs-comment"># we can run one step of inference by feeding input</span>
<span class="hljs-comment"># load the ONNX model</span>
onnx_model = onnx.load(model_path)
sg_ir = sonnx.prepare(onnx_model, device=dev) <span class="hljs-comment"># parse and initiate to a singa model</span>
<span class="hljs-comment"># inference</span>
autograd.training = <span class="hljs-literal">False</span>
print(<span class="hljs-string">'The inference result is:'</span>)
test(Infer(sg_ir), valid_x, valid_y, dev=dev)
</code></pre>
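<p>Building on the imported model, a minimal sketch of predicting the labels of a single validation batch might look as follows (assuming the variables defined above, such as <strong>sg_ir</strong>, <strong>valid_x</strong> and <strong>dev</strong>, are still in scope):</p>
<pre><code class="hljs css language-python">import numpy as np
from singa import tensor

# take the first 64 validation images and run them through the imported model
x_batch = tensor.Tensor(device=dev, data=valid_x[0:64])
pred = Infer(sg_ir).forward(x_batch)

# convert the output back to numpy and pick the most likely class per image
pred_labels = np.argmax(tensor.to_numpy(pred), axis=1)
print('predicted labels of the first batch:', pred_labels)
</code></pre>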
<h3><a class="anchor" aria-hidden="true" id="re-training"></a><a href="#re-training" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Re-training</h3>
<p>Assume that after importing the model we want to re-train it; for this we define a function called <strong>re_train</strong>. Before calling re_train, we set <strong>autograd.training</strong> to <strong>True</strong> so that the autograd operators update their gradients. After training finishes, we set it to <strong>False</strong> again and call the test function to run inference.</p>
<pre><code class="hljs css language-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">re_train</span><span class="hljs-params">(sg_ir,
x,
y,
epochs=<span class="hljs-number">1</span>,
batch_size=<span class="hljs-number">64</span>,
dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
new_model = Infer(sg_ir)
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
<span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
l_idx = b * batch_size
r_idx = (b + <span class="hljs-number">1</span>) * batch_size
x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
output_batch = new_model.forward(x_batch)
loss = autograd.softmax_cross_entropy(output_batch, target_batch)
accuracy_rate = accuracy(tensor.to_numpy(output_batch),
tensor.to_numpy(target_batch))
sgd = opt.SGD(lr=<span class="hljs-number">0.01</span>)
<span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
sgd.update(p, gp)
sgd.step()
<span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
(accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
print(<span class="hljs-string">"re-training completed"</span>)
<span class="hljs-keyword">return</span> new_model
<span class="hljs-comment"># load the ONNX model</span>
onnx_model = onnx.load(model_path)
sg_ir = sonnx.prepare(onnx_model, device=dev)
<span class="hljs-comment"># re-training</span>
autograd.training = <span class="hljs-literal">True</span>
new_model = re_train(sg_ir, train_x, train_y, dev=dev)
autograd.training = <span class="hljs-literal">False</span>
test(new_model, valid_x, valid_y, dev=dev)
</code></pre>
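<p>If the re-trained weights need to be shared again, one possible sketch is to mirror the original export step: run a single forward pass to obtain an input/output pair and feed it to <strong>sonnx.to_onnx</strong> (the file name below is only an example):</p>
<pre><code class="hljs css language-python"># re-export the re-trained model, following the same pattern as the original export
autograd.training = True  # keep the training flag on while tracing, as in the original export
x_batch = tensor.Tensor(device=dev, data=train_x[0:64])
output_batch = new_model.forward(x_batch)
retrained_onnx = sonnx.to_onnx([x_batch], [output_batch])
autograd.training = False

onnx.save(retrained_onnx, os.path.join('/', 'tmp', 'mnist_retrained.onnx'))
print('The re-trained model is saved.')
</code></pre>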
<h3><a class="anchor" aria-hidden="true" id="transfer-learning"></a><a href="#transfer-learning" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Transfer learning</h3>
<p>Finally, if we want to do transfer learning, we can define a class called <strong>Trans</strong> that appends some layers after the ONNX model. For demonstration, we only append several linear (fully connected) and ReLU layers. We also define a <strong>transfer_learning</strong> function to handle the training process of the transfer-learning model. The <strong>autograd.training</strong> flag is handled in the same way as in the previous section.</p>
<pre><code class="hljs css language-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Trans</span>:</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, sg_ir, last_layers)</span>:</span>
self.sg_ir = sg_ir
self.last_layers = last_layers
self.append_linear1 = autograd.Linear(<span class="hljs-number">500</span>, <span class="hljs-number">128</span>, bias=<span class="hljs-literal">False</span>)
self.append_linear2 = autograd.Linear(<span class="hljs-number">128</span>, <span class="hljs-number">32</span>, bias=<span class="hljs-literal">False</span>)
self.append_linear3 = autograd.Linear(<span class="hljs-number">32</span>, <span class="hljs-number">10</span>, bias=<span class="hljs-literal">False</span>)
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
y = sg_ir.run([x], last_layers=self.last_layers)[<span class="hljs-number">0</span>]
y = self.append_linear1(y)
y = autograd.relu(y)
y = self.append_linear2(y)
y = autograd.relu(y)
y = self.append_linear3(y)
y = autograd.relu(y)
<span class="hljs-keyword">return</span> y
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">transfer_learning</span><span class="hljs-params">(sg_ir,
x,
y,
epochs=<span class="hljs-number">1</span>,
batch_size=<span class="hljs-number">64</span>,
dev=device.get_default_device<span class="hljs-params">()</span>)</span>:</span>
batch_number = x.shape[<span class="hljs-number">0</span>] // batch_size
trans_model = Trans(sg_ir, <span class="hljs-number">-1</span>)
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(epochs):
<span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> range(batch_number):
l_idx = b * batch_size
r_idx = (b + <span class="hljs-number">1</span>) * batch_size
x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
output_batch = trans_model.forward(x_batch)
loss = autograd.softmax_cross_entropy(output_batch, target_batch)
accuracy_rate = accuracy(tensor.to_numpy(output_batch),
tensor.to_numpy(target_batch))
sgd = opt.SGD(lr=<span class="hljs-number">0.07</span>)
<span class="hljs-keyword">for</span> p, gp <span class="hljs-keyword">in</span> autograd.backward(loss):
sgd.update(p, gp)
sgd.step()
<span class="hljs-keyword">if</span> b % <span class="hljs-number">1e2</span> == <span class="hljs-number">0</span>:
print(<span class="hljs-string">"acc %6.2f loss, %6.2f"</span> %
(accuracy_rate, tensor.to_numpy(loss)[<span class="hljs-number">0</span>]))
print(<span class="hljs-string">"transfer-learning completed"</span>)
    <span class="hljs-keyword">return</span> trans_model
<span class="hljs-comment"># load the ONNX model</span>
onnx_model = onnx.load(model_path)
sg_ir = sonnx.prepare(onnx_model, device=dev)
<span class="hljs-comment"># transfer-learning</span>
autograd.training = <span class="hljs-literal">True</span>
new_model = transfer_learning(sg_ir, train_x, train_y, dev=dev)
autograd.training = <span class="hljs-literal">False</span>
test(new_model, valid_x, valid_y, dev=dev)
</code></pre>
<h2><a class="anchor" aria-hidden="true" id="supported-operators"></a><a href="#supported-operators" aria-hidden="true" class="hash-link"><svg class="hash-link-icon" aria-hidden="true" height="16" version="1.1" viewBox="0 0 16 16" width="16"><path fill-rule="evenodd" d="M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"></path></svg></a>Supported operators</h2>
<p>The following operators are supported:</p>
<ul>
<li>Conv</li>
<li>Relu</li>
<li>Constant</li>
<li>MaxPool</li>
<li>AveragePool</li>
<li>Softmax</li>
<li>Sigmoid</li>
<li>Add</li>
<li>MatMul</li>
<li>BatchNormalization</li>
<li>Concat</li>
<li>Flatten</li>
<li>Gemm</li>
<li>Reshape</li>
<li>Sum</li>
<li>Cos</li>
<li>Cosh</li>
<li>Sin</li>
<li>Sinh</li>
<li>Tan</li>
<li>Tanh</li>
<li>Acos</li>
<li>Acosh</li>
<li>Asin</li>
<li>Asinh</li>
<li>Atan</li>
<li>Atanh</li>
<li>Selu</li>
<li>Elu</li>
<li>Equal</li>
<li>Less</li>
<li>Sign</li>
<li>Div</li>
<li>Sub</li>
<li>Sqrt</li>
<li>Log</li>
<li>Greater</li>
<li>HardSigmoid</li>
<li>Identity</li>
<li>Softplus</li>
<li>Softsign</li>
<li>Mean</li>
<li>Pow</li>
<li>Clip</li>
<li>PRelu</li>
<li>Mul</li>
<li>Transpose</li>
<li>Max</li>
<li>Min</li>
<li>Shape</li>
<li>And</li>
<li>Or</li>
<li>Xor</li>
<li>Not</li>
<li>Neg</li>
<li>Reciprocal</li>
</ul>
</span></div></article></div><div class="docs-prevnext"></div></div></div><nav class="onPageNav"><ul class="toc-headings"><li><a href="#example-onnx-mnist-on-singa">Example: ONNX mnist on singa</a><ul class="toc-headings"><li><a href="#load-dataset">Load dataset</a></li><li><a href="#mnist-model">MNIST model</a></li><li><a href="#train-mnist-model-and-export-it-to-onnx">Train mnist model and export it to onnx</a></li><li><a href="#inference">Inference</a></li><li><a href="#re-training">Re-training</a></li><li><a href="#transfer-learning">Transfer learning</a></li></ul></li><li><a href="#supported-operators">Supported operators</a></li></ul></nav></div><footer class="nav-footer" id="footer"><section class="sitemap"><a href="/" class="nav-home"><img src="/img/singa-logo-square.png" alt="Apache SINGA" width="66" height="58"/></a><div><h5>Docs</h5><a href="/docs/installation">Getting Started</a><a href="/docs/device">Guides</a><a href="/en/#">API Reference (coming soon)</a><a href="/docs/model-zoo-cnn-cifar10">Model Zoo</a><a href="/docs/download-singa">Development</a></div><div><h5>Community</h5><a href="/en/users.html">User Showcase</a><a href="/docs/history-singa">SINGA History</a><a href="/docs/team-list">SINGA Team</a><a href="/news">SINGA News</a><a href="https://github.com/apache/singa-doc">GitHub</a><div class="social"><a class="github-button" href="https://github.com/apache/singa-doc" data-count-href="/apache/singa/stargazers" data-show-count="true" data-count-aria-label="# stargazers on GitHub" aria-label="Star this project on GitHub">apache/singa-doc</a></div><div class="social"><a href="https://twitter.com/ApacheSINGA" class="twitter-follow-button">Follow @ApacheSINGA</a></div></div><div><h5>Apache Software Foundation</h5><a href="https://apache.org/" target="_blank" rel="noreferrer noopener">Foundation</a><a href="http://www.apache.org/licenses/" target="_blank" rel="noreferrer noopener">License</a><a href="http://www.apache.org/foundation/sponsorship.html" target="_blank" rel="noreferrer noopener">Sponsorship</a><a href="http://www.apache.org/foundation/thanks.html" target="_blank" rel="noreferrer noopener">Thanks</a><a href="http://www.apache.org/events/current-event" target="_blank" rel="noreferrer noopener">Events</a><a href="http://www.apache.org/security/" target="_blank" rel="noreferrer noopener">Security</a></div></section><div style="width:100%;text-align:center"><a href="https://apache.org/" target="_blank" rel="noreferrer noopener" class="ApacheOpenSource"><img src="/img/asf_logo_wide.svg" alt="Apache Open Source"/></a><section class="copyright" style="max-width:60%;margin:0 auto">Copyright © 2020
The Apache Software Foundation. All rights reserved.
Apache SINGA, Apache, the Apache feather logo, and
the Apache SINGA project logos are trademarks of The
Apache Software Foundation. All other marks mentioned
may be trademarks or registered trademarks of their
respective owners.</section></div></footer></div><script>window.twttr=(function(d,s, id){var js,fjs=d.getElementsByTagName(s)[0],t=window.twttr||{};if(d.getElementById(id))return t;js=d.createElement(s);js.id=id;js.src='https://platform.twitter.com/widgets.js';fjs.parentNode.insertBefore(js, fjs);t._e = [];t.ready = function(f) {t._e.push(f);};return t;}(document, 'script', 'twitter-wjs'));</script></body></html>