<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<meta content="Hybridize Gluon models with control flows." property="og:title">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image:secure_url">
<meta content="Hybridize Gluon models with control flows." property="og:description"/>
<title>Hybridize Gluon models with control flows. — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: '../../',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: '.txt'
};
</script>
<script src="https://code.jquery.com/jquery-1.11.1.min.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/versions/1.4.1/searchindex.js"); Search.init();}); </script>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<link href="../../genindex.html" rel="index" title="Index">
<link href="../../search.html" rel="search" title="Search"/>
<link href="index.html" rel="up" title="Tutorials"/>
<link href="../embedded/index.html" rel="next" title="Tutorials"/>
<link href="index.html" rel="prev" title="Tutorials"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</link></link></link></meta></meta></meta></head>
<body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document">
<div class="content-block"><div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="/versions/1.4.1/install/index.html">Install</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.4.1/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="https://www.d2l.ai/">Dive into Deep Learning</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.4.1/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/clojure/index.html">Clojure</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/java/index.html">Java</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/scala/index.html">Scala</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-docs">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs">
<li><a class="main-nav-link" href="/versions/1.4.1/faq/index.html">FAQ</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/tutorials/index.html">Tutorials</a>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.4.1/example">Examples</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/model_zoo/index.html">Model Zoo</a></li>
<li><a class="main-nav-link" href="https://github.com/onnx/onnx-mxnet">ONNX</a></li>
</li></ul>
</span>
<span id="dropdown-menu-position-anchor-community">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community">
<li><a class="main-nav-link" href="http://discuss.mxnet.io">Forum</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.4.1">Github</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/community/contribute.html">Contribute</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/community/ecosystem.html">Ecosystem</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/community/powered_by.html">Powered By</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">1.4.1<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a href="/">master</a></li><li><a href="/versions/1.7.0/">1.7.0</a></li><li><a href=/versions/1.6.0/>1.6.0</a></li><li><a href=/versions/1.5.0/>1.5.0</a></li><li><a href=/versions/1.4.1/>1.4.1</a></li><li><a href=/versions/1.3.1/>1.3.1</a></li><li><a href=/versions/1.2.1/>1.2.1</a></li><li><a href=/versions/1.1.0/>1.1.0</a></li><li><a href=/versions/1.0.0/>1.0.0</a></li><li><a href=/versions/0.12.1/>0.12.1</a></li><li><a href=/versions/0.11.0/>0.11.0</a></li></ul></span></nav>
<script> function getRootPath(){ return "../../" } </script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu" id="burgerMenu">
<li><a href="/versions/1.4.1/install/index.html">Install</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/tutorials/index.html">Tutorials</a></li>
<li class="dropdown-submenu dropdown">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Gluon</a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.4.1/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="http://gluon.mxnet.io">The Straight Dope (Tutorials)</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.4.1/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/clojure/index.html">Clojure</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/java/index.html">Java</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/1.4.1/api/scala/index.html">Scala</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Docs</a>
<ul class="dropdown-menu">
<li><a href="/versions/1.4.1/faq/index.html" tabindex="-1">FAQ</a></li>
<li><a href="/versions/1.4.1/tutorials/index.html" tabindex="-1">Tutorials</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/1.4.1/example" tabindex="-1">Examples</a></li>
<li><a href="/versions/1.4.1/architecture/index.html" tabindex="-1">Architecture</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home" tabindex="-1">Developer Wiki</a></li>
<li><a href="/versions/1.4.1/model_zoo/index.html" tabindex="-1">Gluon Model Zoo</a></li>
<li><a href="https://github.com/onnx/onnx-mxnet" tabindex="-1">ONNX</a></li>
</ul>
</li>
<li class="dropdown-submenu dropdown">
<a aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" role="button" tabindex="-1">Community</a>
<ul class="dropdown-menu">
<li><a href="http://discuss.mxnet.io" tabindex="-1">Forum</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/1.4.1" tabindex="-1">Github</a></li>
<li><a href="/versions/1.4.1/community/contribute.html" tabindex="-1">Contribute</a></li>
<li><a href="/versions/1.4.1/community/ecosystem.html" tabindex="-1">Ecosystem</a></li>
<li><a href="/versions/1.4.1/community/powered_by.html" tabindex="-1">Powered By</a></li>
</ul>
</li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">1.4.1</a><ul class="dropdown-menu"><li><a tabindex="-1" href=/>master</a></li><li><a tabindex="-1" href=/versions/1.6.0/>1.6.0</a></li><li><a tabindex="-1" href=/versions/1.5.0/>1.5.0</a></li><li><a tabindex="-1" href=/versions/1.4.1/>1.4.1</a></li><li><a tabindex="-1" href=/versions/1.3.1/>1.3.1</a></li><li><a tabindex="-1" href=/versions/1.2.1/>1.2.1</a></li><li><a tabindex="-1" href=/versions/1.1.0/>1.1.0</a></li><li><a tabindex="-1" href=/versions/1.0.0/>1.0.0</a></li><li><a tabindex="-1" href=/versions/0.12.1/>0.12.1</a></li><li><a tabindex="-1" href=/versions/0.11.0/>0.11.0</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes">
<input name="area" type="hidden" value="default"/>
</input></form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<script type="text/javascript">
$('body').css('background', 'white');
</script>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../api/index.html">MXNet APIs</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">MXNet Architecture</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../community/index.html">MXNet Community</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../faq/index.html">MXNet FAQ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../gluon/index.html">About Gluon</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../install/index.html">Installing MXNet</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../install/index.html#nvidia-jetson-tx-family">Nvidia Jetson TX family</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../install/index.html#source-download">Source Download</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../model_zoo/index.html">MXNet Model Zoo</a></li>
<li class="toctree-l1"><a class="reference internal" href="../index.html">Tutorials</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="page-tracker"></div>
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at --><!--- http://www.apache.org/licenses/LICENSE-2.0 --><!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. --><div class="section" id="hybridize-gluon-models-with-control-flows">
<span id="hybridize-gluon-models-with-control-flows"></span><h1>Hybridize Gluon models with control flows.<a class="headerlink" href="#hybridize-gluon-models-with-control-flows" title="Permalink to this headline"></a></h1>
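<p>MXNet provides three control flow operators: <code class="docutils literal"><span class="pre">cond</span></code>, <code class="docutils literal"><span class="pre">foreach</span></code> and <code class="docutils literal"><span class="pre">while_loop</span></code>. Like other MXNet operators, each has an NDArray version (in <code class="docutils literal"><span class="pre">mx.nd.contrib</span></code>) and a Symbol version (in <code class="docutils literal"><span class="pre">mx.sym.contrib</span></code>), and the two versions have exactly the same semantics. Because <code class="docutils literal"><span class="pre">hybrid_forward</span></code> receives either the NDArray or the Symbol API as <code class="docutils literal"><span class="pre">F</span></code>, calling these operators through <code class="docutils literal"><span class="pre">F.contrib</span></code> lets the same Gluon code run imperatively before hybridization and as a symbolic graph after it.</p>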
<p>In this tutorial, we use a few examples to demonstrate how to use control flow operators in Gluon and how to hybridize a model that requires control flow.</p>
<div class="section" id="prepare-running-the-code">
<span id="prepare-running-the-code"></span><h2>Prepare to run the code<a class="headerlink" href="#prepare-running-the-code" title="Permalink to this headline"></a></h2>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="kn">from</span> <span class="nn">mxnet.gluon</span> <span class="kn">import</span> <span class="n">HybridBlock</span>
</pre></div>
</div>
</div>
<div class="section" id="foreach">
<span id="foreach"></span><h2>foreach<a class="headerlink" href="#foreach" title="Permalink to this headline"></a></h2>
<p><code class="docutils literal"><span class="pre">foreach</span></code> is a for loop that iterates over the first dimension of the input data, which can be a single array or a list of arrays. It has the following signature:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">foreach</span><span class="p">(</span><span class="n">body</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">init_states</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span> <span class="o">=></span> <span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">states</span><span class="p">)</span>
</pre></div>
</div>
<p>It calls the Python function <code class="docutils literal"><span class="pre">body</span></code> once for every slice along the first dimension of the input arrays. The <code class="docutils literal"><span class="pre">body</span></code> function has the following signature:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">body</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">states</span><span class="p">)</span> <span class="o">=></span> <span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">states</span><span class="p">)</span>
</pre></div>
</div>
<p>The inputs of the <code class="docutils literal"><span class="pre">body</span></code> function have two parts. <code class="docutils literal"><span class="pre">data</span></code> is a slice of the input array if <code class="docutils literal"><span class="pre">foreach</span></code> receives a single array, or a list of slices (one per input array) if it receives a list of arrays. <code class="docutils literal"><span class="pre">states</span></code> are the states returned by the previous iteration; in the first iteration they are <code class="docutils literal"><span class="pre">init_states</span></code>. The outputs of the <code class="docutils literal"><span class="pre">body</span></code> function also have two parts. <code class="docutils literal"><span class="pre">outputs</span></code> is an array or a list of arrays. <code class="docutils literal"><span class="pre">states</span></code> are the new states, which are passed to the next iteration. The <code class="docutils literal"><span class="pre">outputs</span></code> from all iterations are stacked along a new first dimension to form the outputs of <code class="docutils literal"><span class="pre">foreach</span></code>. The states returned by the last iteration become the states of <code class="docutils literal"><span class="pre">foreach</span></code>.</p>
<p>The following pseudocode illustrates the execution of <code class="docutils literal"><span class="pre">foreach</span></code> for a single input array.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">foreach</span><span class="p">(</span><span class="n">body</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">init_states</span><span class="p">):</span>
    <span class="c1"># Single input array only; the real operator also accepts a list of arrays.</span>
    <span class="n">states</span> <span class="o">=</span> <span class="n">init_states</span>
    <span class="n">outs</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
        <span class="c1"># Take the i-th slice along the first dimension.</span>
        <span class="n">s</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
        <span class="c1"># The states returned here are fed into the next iteration.</span>
        <span class="n">out</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">body</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="n">states</span><span class="p">)</span>
        <span class="n">outs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
    <span class="c1"># Per-iteration outputs are stacked along a new first dimension.</span>
    <span class="n">outs</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="o">*</span><span class="n">outs</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">outs</span><span class="p">,</span> <span class="n">states</span>
</pre></div>
</div>
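<p>The pseudocode above covers only a single input array. The sketch below is a plain-Python model of the general contract described earlier, not MXNet code. For illustration it assumes that an "array" is a Python list, that a list of input arrays is passed as a tuple of lists, and that a <code class="docutils literal"><span class="pre">body</span></code> returning an empty list has no per-step output. The helper names <code class="docutils literal"><span class="pre">foreach_model</span></code> and <code class="docutils literal"><span class="pre">dot_body</span></code> are hypothetical.</p>
<div class="highlight-python"><div class="highlight"><pre>
```python
def foreach_model(body, data, init_states):
    """Plain-Python model of foreach semantics (hypothetical helper, not MXNet).

    data: one list, or a tuple of equal-length lists (several input arrays).
    init_states: whatever body expects as its states on the first step.
    """
    multiple = isinstance(data, tuple)
    num_steps = len(data[0]) if multiple else len(data)
    states = init_states
    outs = []
    for i in range(num_steps):
        # One slice per step: a single item, or a list with one item per input.
        s = [d[i] for d in data] if multiple else data[i]
        out, states = body(s, states)
        outs.append(out)
    # A body that always returns [] has no per-step output.
    if all(o == [] for o in outs):
        outs = []
    return outs, states


def dot_body(xs, states):
    x, y = xs              # one slice from each of the two inputs
    (acc,) = states        # a single carried state, wrapped in a list
    acc = acc + x * y
    return acc, [acc]      # per-step output, new states


outs, states = foreach_model(dot_body, ([1, 2, 3], [4, 5, 6]), [0])
print(outs)    # [4, 14, 32]
print(states)  # [32]
```
</pre></div></div>
<p>Here the running dot product is both the per-step output and the carried state. In MXNet itself, the per-step outputs would be stacked into one array rather than collected in a Python list.</p>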
<div class="section" id="example-1-foreach-works-like-map">
<span id="example-1-foreach-works-like-map"></span><h3>Example 1: <code class="docutils literal"><span class="pre">foreach</span></code> works like map<a class="headerlink" href="#example-1-foreach-works-like-map" title="Permalink to this headline"></a></h3>
<p><code class="docutils literal"><span class="pre">foreach</span></code> can work like the map function of a functional language. In this case, the states of <code class="docutils literal"><span class="pre">foreach</span></code> can be an empty list, which means no state is carried across iterations.</p>
<p>In this example, we use <code class="docutils literal"><span class="pre">foreach</span></code> to add one to every element of an array.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="p">[</span> <span class="mf">0.</span> <span class="mf">1.</span> <span class="mf">2.</span> <span class="mf">3.</span> <span class="mf">4.</span><span class="p">]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">5</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span>
</pre></div>
</div>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># Body function: add 1 to the current slice and carry no states.</span>
<span class="k">def</span> <span class="nf">add1</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">_</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">data</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="p">[]</span>
<span class="k">class</span> <span class="nc">Map</span><span class="p">(</span><span class="n">HybridBlock</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">hybrid_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">F</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
        <span class="c1"># F is mx.nd before hybridize() and mx.sym after it.</span>
        <span class="n">out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">foreach</span><span class="p">(</span><span class="n">add1</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="p">[])</span>
        <span class="k">return</span> <span class="n">out</span>
<span class="n">map_layer</span> <span class="o">=</span> <span class="n">Map</span><span class="p">()</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">map_layer</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="p">[[</span> <span class="mf">1.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">2.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">3.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">4.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">5.</span><span class="p">]]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">5</span><span class="n">x1</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span>
</pre></div>
</div>
<p>We can hybridize the block and run the computation again. It produces the same result.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">map_layer</span><span class="o">.</span><span class="n">hybridize</span><span class="p">()</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">map_layer</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="p">[[</span> <span class="mf">1.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">2.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">3.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">4.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">5.</span><span class="p">]]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">5</span><span class="n">x1</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span>
</pre></div>
</div>
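<p>For comparison, the sketch below spells out the same stateless pattern in plain Python on a list, not MXNet code. Threading an empty state through every step and collecting the outputs gives the same result as Python's built-in <code class="docutils literal"><span class="pre">map</span></code>. The helper <code class="docutils literal"><span class="pre">foreach_map</span></code> is hypothetical and only mirrors the <code class="docutils literal"><span class="pre">(outputs,</span> <span class="pre">states)</span></code> contract of <code class="docutils literal"><span class="pre">add1</span></code>.</p>
<div class="highlight-python"><div class="highlight"><pre>
```python
def add1(x, _):
    # Same contract as the tutorial's add1: (output, new_states), no state kept.
    return x + 1, []


def foreach_map(body, data):
    """Hypothetical plain-Python stand-in for foreach with empty states."""
    states, outs = [], []
    for x in data:
        out, states = body(x, states)
        outs.append(out)
    return outs, states


data = [0, 1, 2, 3, 4]
outs, states = foreach_map(add1, data)
print(outs)    # [1, 2, 3, 4, 5]
print(states)  # []
```
</pre></div></div>
<p>In MXNet the per-step outputs are stacked along a new first dimension, which is why the result above has shape 5x1 rather than a flat list of five values.</p>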
</div>
<div class="section" id="example-2-foreach-works-like-scan">
<span id="example-2-foreach-works-like-scan"></span><h3>Example 2: <code class="docutils literal"><span class="pre">foreach</span></code> works like scan<a class="headerlink" href="#example-2-foreach-works-like-scan" title="Permalink to this headline"></a></h3>
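<p><code class="docutils literal"><span class="pre">foreach</span></code> can also work like a scan function in a functional language. In this case, the <code class="docutils literal"><span class="pre">body</span></code> function returns an empty list as its outputs, so only the final state is kept.</p>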
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">add_to_state</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span>
    <span class="c1"># No per-step output; the new state is the running total.</span>
    <span class="k">return</span> <span class="p">[],</span> <span class="n">state</span> <span class="o">+</span> <span class="n">data</span>
<span class="k">class</span> <span class="nc">Scan</span><span class="p">(</span><span class="n">HybridBlock</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">hybrid_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">F</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
        <span class="c1"># Start the running total from a zero array of shape (1,).</span>
        <span class="n">_</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">foreach</span><span class="p">(</span><span class="n">add_to_state</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">F</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,)))</span>
        <span class="k">return</span> <span class="n">state</span>
<span class="n">scan_layer</span> <span class="o">=</span> <span class="n">Scan</span><span class="p">()</span>
<span class="n">state</span> <span class="o">=</span> <span class="n">scan_layer</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">state</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="p">[</span> <span class="mf">0.</span> <span class="mf">1.</span> <span class="mf">2.</span> <span class="mf">3.</span> <span class="mf">4.</span><span class="p">]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">5</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span>
<span class="p">[</span> <span class="mf">10.</span><span class="p">]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">1</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span>
</pre></div>
</div>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">scan_layer</span><span class="o">.</span><span class="n">hybridize</span><span class="p">()</span>
<span class="n">state</span> <span class="o">=</span> <span class="n">scan_layer</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">state</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="p">[</span> <span class="mf">10.</span><span class="p">]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">1</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span>
</pre></div>
</div>
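<p>Because only the final state is returned, this pattern corresponds to a fold (reduce) in plain Python. The stdlib sketch below is not MXNet code. It computes the same total as <code class="docutils literal"><span class="pre">Scan</span></code> in two ways: once with <code class="docutils literal"><span class="pre">functools.reduce</span></code>, and once by threading a state through a hypothetical body function that mirrors the one above.</p>
<div class="highlight-python"><div class="highlight"><pre>
```python
import operator
from functools import reduce

data = [0, 1, 2, 3, 4]

# Fold: combine each element into an accumulator and keep only the final value.
total = reduce(operator.add, data, 0)
print(total)  # 10


def add_to_state(x, state):
    # Mirrors the body above: no per-step output, state becomes the running total.
    return [], state + x


state = 0  # plays the role of init_states
for x in data:
    _, state = add_to_state(x, state)
print(state)  # 10
```
</pre></div></div>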
</div>
<div class="section" id="example-3-foreach-with-both-outputs-and-states">
<span id="example-3-foreach-with-both-outputs-and-states"></span><h3>Example 3: <code class="docutils literal"><span class="pre">foreach</span></code> with both outputs and states<a class="headerlink" href="#example-3-foreach-with-both-outputs-and-states" title="Permalink to this headline"></a></h3>
<p>This is probably the most common use case of <code class="docutils literal"><span class="pre">foreach</span></code>. We extend the previous scan example so that the body function returns the running total both as its output for each step and as its new state.</p>
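<p>In plain-Python terms, keeping every intermediate state as an output, in addition to the final state, is a scan. The stdlib sketch below is not MXNet code. It uses <code class="docutils literal"><span class="pre">itertools.accumulate</span></code> to produce the running totals that <code class="docutils literal"><span class="pre">ScanV2</span></code> collects as its per-step outputs when it starts from zero.</p>
<div class="highlight-python"><div class="highlight"><pre>
```python
from itertools import accumulate

data = [0, 1, 2, 3, 4]

# Scan: every intermediate accumulator value is kept as an output.
running = list(accumulate(data))
# The last running total is the final state.
final_state = running[-1]
print(running)      # [0, 1, 3, 6, 10]
print(final_state)  # 10
```
</pre></div></div>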
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">add_to_state</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span>
    <span class="c1"># Emit the running total as this step's output and also carry it as the new state.</span>
    <span class="k">return</span> <span class="n">state</span> <span class="o">+</span> <span class="n">data</span><span class="p">,</span> <span class="n">state</span> <span class="o">+</span> <span class="n">data</span>
<span class="k">class</span> <span class="nc">ScanV2</span><span class="p">(</span><span class="n">HybridBlock</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">hybrid_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">F</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
        <span class="n">out</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">foreach</span><span class="p">(</span><span class="n">add_to_state</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">F</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,)))</span>
        <span class="k">return</span> <span class="n">out</span><span class="p">,</span> <span class="n">state</span>
<span class="n">scan_layer</span> <span class="o">=</span> <span class="n">ScanV2</span><span class="p">()</span>
<span class="n">out</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">scan_layer</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">state</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="p">[[</span> <span class="mf">0.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">1.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">3.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">6.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">10.</span><span class="p">]]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">5</span><span class="n">x1</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span>
<span class="p">[</span> <span class="mf">10.</span><span class="p">]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">1</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span>
</pre></div>
</div>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">scan_layer</span><span class="o">.</span><span class="n">hybridize</span><span class="p">()</span>
<span class="n">out</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">scan_layer</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">state</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="p">[[</span> <span class="mf">0.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">1.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">3.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">6.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">10.</span><span class="p">]]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">5</span><span class="n">x1</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span>
<span class="p">[</span> <span class="mf">10.</span><span class="p">]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">1</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span>
</pre></div>
</div>
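<p>To make these semantics concrete, the following is a minimal plain-Python model of what <code class="docutils literal"><span class="pre">foreach</span></code> computes here. It is a sketch for illustration only, not MXNet's implementation: <code class="docutils literal"><span class="pre">foreach_reference</span></code> and <code class="docutils literal"><span class="pre">running_sum</span></code> are hypothetical names, Python floats stand in for NDArrays, and the input values 0 to 4 are assumed because they are consistent with the printed outputs above.</p>
<div class="highlight-python"><div class="highlight"><pre>def foreach_reference(body, data, init_states):
    """Hypothetical plain-Python model of foreach (not part of MXNet).

    Iterates over the first axis of data, collects every per-step output,
    and threads the states from one step to the next.
    """
    outputs = []
    states = init_states
    for item in data:
        out, states = body(item, states)
        outputs.append(out)
    return outputs, states


def running_sum(item, state):
    # Same logic as `sum` above: the new state is also the per-step output.
    new_state = state + item
    return new_state, new_state


outs, final = foreach_reference(running_sum, [0.0, 1.0, 2.0, 3.0, 4.0], 0.0)
print(outs)   # [0.0, 1.0, 3.0, 6.0, 10.0]
print(final)  # 10.0
</pre></div></div>
<p>The per-step outputs are kept in order (MXNet stacks them into the 5x1 array shown above), while only the last state is returned.</p>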
</div>
<div class="section" id="example-4-use-foreach-to-run-an-rnn-on-a-variable-length-sequence">
<span id="example-4-use-foreach-to-run-an-rnn-on-a-variable-length-sequence"></span><h3>Example 4: use <code class="docutils literal"><span class="pre">foreach</span></code> to run an RNN on a variable-length sequence<a class="headerlink" href="#example-4-use-foreach-to-run-an-rnn-on-a-variable-length-sequence" title="Permalink to this headline"></a></h3>
<p>The previous examples illustrate <code class="docutils literal"><span class="pre">foreach</span></code> with simple use cases. Here we show how to process a batch of variable-length sequences with <code class="docutils literal"><span class="pre">foreach</span></code>: the loop runs over all time steps, but once a sequence passes its valid length, its state is carried over unchanged and its outputs are masked to zero with <code class="docutils literal"><span class="pre">SequenceMask</span></code>. The same idea is used by <code class="docutils literal"><span class="pre">dynamic_rnn</span></code> in TensorFlow.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">DynamicRNNLayer</span><span class="p">(</span><span class="n">HybridBlock</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cell</span><span class="p">,</span> <span class="n">prefix</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">DynamicRNNLayer</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">prefix</span><span class="o">=</span><span class="n">prefix</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="n">params</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cell</span> <span class="o">=</span> <span class="n">cell</span>
<span class="k">def</span> <span class="nf">hybrid_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">F</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">begin_state</span><span class="p">,</span> <span class="n">valid_length</span><span class="p">):</span>
<span class="n">states</span> <span class="o">=</span> <span class="n">begin_state</span>
<span class="n">zeros</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">states</span><span class="p">:</span>
<span class="n">zeros</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">s</span><span class="p">))</span>
<span class="c1"># the last state is the iteration number.</span>
<span class="n">states</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">)))</span>
<span class="k">def</span> <span class="nf">loop_body</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">states</span><span class="p">):</span>
<span class="n">cell_states</span> <span class="o">=</span> <span class="n">states</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="c1"># Get the iteration number from the states.</span>
<span class="n">iter_no</span> <span class="o">=</span> <span class="n">states</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="n">out</span><span class="p">,</span> <span class="n">new_states</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cell</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">cell_states</span><span class="p">)</span>
<span class="c1"># Copy the old state if we have reached the end of a sequence.</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">state</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">cell_states</span><span class="p">):</span>
<span class="n">new_states</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">broadcast_greater</span><span class="p">(</span><span class="n">valid_length</span><span class="p">,</span> <span class="n">iter_no</span><span class="p">),</span>
<span class="n">new_states</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">state</span><span class="p">)</span>
<span class="n">new_states</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">iter_no</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">out</span><span class="p">,</span> <span class="n">new_states</span>
<span class="n">outputs</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">foreach</span><span class="p">(</span><span class="n">loop_body</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">states</span><span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">SequenceMask</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">sequence_length</span><span class="o">=</span><span class="n">valid_length</span><span class="p">,</span>
<span class="n">use_sequence_length</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="c1"># the last state is the iteration number. We don't need it.</span>
<span class="k">return</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">states</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="n">seq_len</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">input_size</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="mi">6</span>
<span class="n">rnn_data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">input_size</span><span class="p">))</span>
<span class="n">init_states</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">)]</span>
<span class="n">valid_length</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">round</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)))</span>
<span class="n">lstm</span> <span class="o">=</span> <span class="n">DynamicRNNLayer</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">gluon</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">))</span>
<span class="n">lstm</span><span class="o">.</span><span class="n">initialize</span><span class="p">()</span>
<span class="n">res</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">lstm</span><span class="p">(</span><span class="n">rnn_data</span><span class="p">,</span> <span class="p">[</span><span class="n">x</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">init_states</span><span class="p">],</span> <span class="n">valid_length</span><span class="p">)</span>
<span class="n">lstm</span><span class="o">.</span><span class="n">hybridize</span><span class="p">()</span>
<span class="n">res</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">lstm</span><span class="p">(</span><span class="n">rnn_data</span><span class="p">,</span> <span class="p">[</span><span class="n">x</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">init_states</span><span class="p">],</span> <span class="n">valid_length</span><span class="p">)</span>
</pre></div>
</div>
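<p>The masking logic of <code class="docutils literal"><span class="pre">DynamicRNNLayer</span></code> can be modeled in plain Python for a single sequence. This is a sketch under simplifying assumptions: <code class="docutils literal"><span class="pre">masked_rnn_reference</span></code> and the toy cell <code class="docutils literal"><span class="pre">accumulate</span></code> are hypothetical, scalars stand in for state tensors, and the batch dimension is dropped. As in the layer above, the cell still runs at every step; only the state update and the output are suppressed once the step index reaches the valid length.</p>
<div class="highlight-python"><div class="highlight"><pre>def masked_rnn_reference(cell, inputs, init_state, valid_length):
    """Hypothetical plain-Python model of DynamicRNNLayer for one sequence.

    cell(x, state) returns (out, new_state). Steps at or beyond valid_length
    keep the old state (like F.where on the states) and emit 0.0 (like
    F.SequenceMask on the outputs).
    """
    state = init_state
    outputs = []
    for step, x in enumerate(inputs):
        out, new_state = cell(x, state)   # the cell runs at every step
        if valid_length > step:
            state = new_state             # still inside the sequence
            outputs.append(out)
        else:
            outputs.append(0.0)           # masked output; state unchanged
    return outputs, state


def accumulate(x, state):
    # Toy "cell": the state is a running sum and the output equals the state.
    return state + x, state + x


outs, final = masked_rnn_reference(accumulate, [1.0, 2.0, 3.0, 4.0], 0.0, valid_length=2)
print(outs)   # [1.0, 3.0, 0.0, 0.0]
print(final)  # 3.0
</pre></div></div>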
</div>
</div>
<div class="section" id="while-loop">
<span id="while-loop"></span><h2>while_loop<a class="headerlink" href="#while-loop" title="Permalink to this headline"></a></h2>
<p><code class="docutils literal"><span class="pre">while_loop</span></code> defines a while loop. It has the following signature:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">while_loop</span><span class="p">(</span><span class="n">cond</span><span class="p">,</span> <span class="n">body</span><span class="p">,</span> <span class="n">loop_vars</span><span class="p">,</span> <span class="n">max_iterations</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span> <span class="o">=></span> <span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">states</span><span class="p">)</span>
</pre></div>
</div>
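<p>Instead of iterating over the first dimension of an array, <code class="docutils literal"><span class="pre">while_loop</span></code> evaluates the <code class="docutils literal"><span class="pre">cond</span></code> function before every iteration and runs the <code class="docutils literal"><span class="pre">body</span></code> function as long as <code class="docutils literal"><span class="pre">cond</span></code> returns true, for at most <code class="docutils literal"><span class="pre">max_iterations</span></code> iterations. <code class="docutils literal"><span class="pre">cond</span></code> receives the same loop variables as <code class="docutils literal"><span class="pre">body</span></code>. The signature of the <code class="docutils literal"><span class="pre">body</span></code> function is as follows:</p>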
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">body</span><span class="p">(</span><span class="n">state1</span><span class="p">,</span> <span class="n">state2</span><span class="p">,</span> <span class="o">...</span><span class="p">)</span> <span class="o">=></span> <span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">states</span><span class="p">)</span>
</pre></div>
</div>
<p>The inputs of the <code class="docutils literal"><span class="pre">body</span></code> function in <code class="docutils literal"><span class="pre">while_loop</span></code> differ slightly from those in <code class="docutils literal"><span class="pre">foreach</span></code>: <code class="docutils literal"><span class="pre">body</span></code> takes a variable number of arguments, one per loop variable. The outputs of <code class="docutils literal"><span class="pre">body</span></code> also have two parts: <code class="docutils literal"><span class="pre">outputs</span></code> is an array or a list of arrays, and <code class="docutils literal"><span class="pre">states</span></code> holds the updated loop variables, which are passed to the next iteration as the inputs of <code class="docutils literal"><span class="pre">body</span></code>. As in <code class="docutils literal"><span class="pre">foreach</span></code>, both <code class="docutils literal"><span class="pre">outputs</span></code> and <code class="docutils literal"><span class="pre">states</span></code> can be empty lists. The <code class="docutils literal"><span class="pre">outputs</span></code> from all iterations are stacked along a new first dimension to form the outputs of <code class="docutils literal"><span class="pre">while_loop</span></code>.</p>
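<p>The loop semantics described above can be summarized in a short plain-Python model. It is a sketch, not MXNet's code: <code class="docutils literal"><span class="pre">while_loop_reference</span></code>, <code class="docutils literal"><span class="pre">below_100</span></code> and <code class="docutils literal"><span class="pre">double</span></code> are hypothetical names, Python numbers stand in for NDArrays, and, unlike the real operator, the model does not pad its outputs to <code class="docutils literal"><span class="pre">max_iterations</span></code>.</p>
<div class="highlight-python"><div class="highlight"><pre>def while_loop_reference(cond, body, loop_vars, max_iterations):
    """Hypothetical plain-Python model of while_loop (outputs not padded).

    cond(*loop_vars) returns a bool; body(*loop_vars) returns
    (output, new_loop_vars).
    """
    outputs = []
    steps = 0
    # Stop when cond is false or when max_iterations steps have run.
    while max_iterations > steps and cond(*loop_vars):
        out, loop_vars = body(*loop_vars)
        outputs.append(out)
        steps += 1
    return outputs, loop_vars


def below_100(x, n):
    return 100 > x


def double(x, n):
    # Output the doubled value and pass it on as the next loop variable.
    return x * 2, [x * 2, n + 1]


outs, final_vars = while_loop_reference(below_100, double, [1, 0], max_iterations=10)
print(outs)        # [2, 4, 8, 16, 32, 64, 128]
print(final_vars)  # [128, 7]

capped, capped_vars = while_loop_reference(below_100, double, [1, 0], max_iterations=3)
print(capped)       # [2, 4, 8]
print(capped_vars)  # [8, 3]
</pre></div></div>
<p>The second call shows that <code class="docutils literal"><span class="pre">max_iterations</span></code> caps the number of iterations even when <code class="docutils literal"><span class="pre">cond</span></code> would still return true.</p>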
<div class="section" id="example-5-scan-with-while-loop">
<span id="example-5-scan-with-while-loop"></span><h3>Example 5: scan with while_loop<a class="headerlink" href="#example-5-scan-with-while-loop" title="Permalink to this headline"></a></h3>
<p><code class="docutils literal"><span class="pre">while_loop</span></code> is more general than <code class="docutils literal"><span class="pre">foreach</span></code>. We can also use it to iterate over an array and sum all of its values together. In this example, instead of summing over the entire array, we only sum over the first 4 elements.</p>
<p><strong>Note</strong>: in the current implementation of <code class="docutils literal"><span class="pre">while_loop</span></code>, the length of the output arrays along the first dimension is determined by <code class="docutils literal"><span class="pre">max_iterations</span></code>, not by the number of iterations actually executed. As a result, even though the loop in this example runs only 4 iterations, it still outputs an array of 5 elements, and the last element is filled with an arbitrary value.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">ScanV2</span><span class="p">(</span><span class="n">HybridBlock</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">hybrid_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">F</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">sum</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">i</span><span class="p">):</span>
<span class="n">s</span> <span class="o">=</span> <span class="n">state</span> <span class="o">+</span> <span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="k">return</span> <span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">s</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
<span class="k">def</span> <span class="nf">sum_cond</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">i</span><span class="p">):</span>
<span class="k">return</span> <span class="n">i</span> <span class="o"><</span> <span class="mi">4</span>
<span class="n">out</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">while_loop</span><span class="p">(</span><span class="n">sum_cond</span><span class="p">,</span> <span class="nb">sum</span><span class="p">,</span>
<span class="p">[</span><span class="n">F</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">)),</span> <span class="n">F</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">))],</span> <span class="n">max_iterations</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>
<span class="k">return</span> <span class="n">out</span><span class="p">,</span> <span class="n">state</span>
<span class="n">scan_layer</span> <span class="o">=</span> <span class="n">ScanV2</span><span class="p">()</span>
<span class="n">out</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">scan_layer</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">state</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="p">[[</span> <span class="mf">0.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">1.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">3.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">6.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">0.</span><span class="p">]]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">5</span><span class="n">x1</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span>
<span class="p">[</span>
<span class="p">[</span> <span class="mf">6.</span><span class="p">]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">1</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span><span class="p">,</span>
<span class="p">[</span> <span class="mf">4.</span><span class="p">]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">1</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span><span class="p">]</span>
</pre></div>
</div>
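<p>Because the padded rows carry no meaning, a caller usually has to drop them. The plain-Python sketch below reproduces Example 5 and uses the final loop counter to keep only the real rows. It assumes <code class="docutils literal"><span class="pre">data</span></code> holds the values 0 to 4, which is consistent with the printed output; <code class="docutils literal"><span class="pre">pad_to_max_iterations</span></code> is a hypothetical helper, and its fill value of 0.0 is only a placeholder for the arbitrary value MXNet leaves there.</p>
<div class="highlight-python"><div class="highlight"><pre>def pad_to_max_iterations(step_outputs, max_iterations, fill=0.0):
    # Hypothetical helper imitating the fixed-size outputs of while_loop.
    # The real padding value is arbitrary; 0.0 is only a placeholder.
    return step_outputs + [fill] * (max_iterations - len(step_outputs))


# Example 5 in plain Python (data assumed to be 0..4, as the output suggests).
data = [0.0, 1.0, 2.0, 3.0, 4.0]
state, i, step_outputs = 0.0, 0, []
while 4 > i:
    state = state + data[i]
    step_outputs.append(state)
    i += 1

padded = pad_to_max_iterations(step_outputs, max_iterations=5)
valid = padded[:i]  # the final loop counter tells how many rows are real
print(padded)  # [0.0, 1.0, 3.0, 6.0, 0.0]
print(valid)   # [0.0, 1.0, 3.0, 6.0]
</pre></div></div>
<p>Outside <code class="docutils literal"><span class="pre">hybrid_forward</span></code>, the same trimming could be applied to the NDArray results of Example 5, for instance <code class="docutils literal"><span class="pre">out[:int(state[1].asscalar())]</span></code>, since the second loop variable holds the number of iterations that were actually run.</p>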
</div>
</div>
<div class="section" id="cond">
<span id="cond"></span><h2>cond<a class="headerlink" href="#cond" title="Permalink to this headline"></a></h2>
<p><code class="docutils literal"><span class="pre">cond</span></code> defines an if condition. It has the following signature:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">cond</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">then_func</span><span class="p">,</span> <span class="n">else_func</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span>
</pre></div>
</div>
<p><code class="docutils literal"><span class="pre">cond</span></code> checks <code class="docutils literal"><span class="pre">pred</span></code>, which is a symbol or an NDArray with one element. If its value is true, it calls <code class="docutils literal"><span class="pre">then_func</span></code>. Otherwise, it calls <code class="docutils literal"><span class="pre">else_func</span></code>. The signature of <code class="docutils literal"><span class="pre">then_func</span></code> and <code class="docutils literal"><span class="pre">else_func</span></code> are as follows:</p>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">func</span><span class="p">()</span> <span class="o">=></span> <span class="p">[</span><span class="n">outputs</span><span class="p">]</span>
</pre></div>
</div>
<p><code class="docutils literal"><span class="pre">cond</span></code> requires all outputs from <code class="docutils literal"><span class="pre">then_func</span></code> and <code class="docutils literal"><span class="pre">else_func</span></code> have the same number of Symbols/NDArrays with the same shapes and data types.</p>
<div class="section" id="example-6-skip-rnn-computation-with-cond">
<span id="example-6-skip-rnn-computation-with-cond"></span><h3>Example 6: skip RNN computation with cond<a class="headerlink" href="#example-6-skip-rnn-computation-with-cond" title="Permalink to this headline"></a></h3>
<p>Example 4 shows how to process a batch of sequences with different lengths. It runs the RNN cell at every time step and then discards the results that fall past the end of each sequence.</p>
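<p>In this example, we show how to skip the computation once we have reached the end of a sequence, whose length is given by <code class="docutils literal"><span class="pre">length</span></code>. Because the predicate passed to <code class="docutils literal"><span class="pre">cond</span></code> must contain a single element, the code below only works for a batch with one sequence.</p>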
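<p>The practical difference between the two strategies is how many times the cell is evaluated. The plain-Python sketch below only counts cell invocations for one sequence; <code class="docutils literal"><span class="pre">count_cell_calls</span></code> is a hypothetical helper, and it ignores the cost of the predicate check and of the <code class="docutils literal"><span class="pre">copy_states</span></code> branch.</p>
<div class="highlight-python"><div class="highlight"><pre>def count_cell_calls(seq_len, length, skip):
    """Count how often a (hypothetical) RNN cell runs on one sequence.

    skip=False mimics Example 4: every step runs the cell, and results past
    the end of the sequence are discarded afterwards.
    skip=True mimics the cond-based approach: once the step index reaches
    length, the copy_states branch runs instead of the cell.
    """
    calls = 0
    for i in range(seq_len):
        if skip and not (length > i):
            continue  # copy_states branch: the cell is not evaluated
        calls += 1    # run_rnn branch: the cell is evaluated
    return calls


print(count_cell_calls(seq_len=5, length=2, skip=False))  # 5
print(count_cell_calls(seq_len=5, length=2, skip=True))   # 2
</pre></div></div>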
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">SkipRNNCell</span><span class="p">(</span><span class="n">HybridBlock</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cell</span><span class="p">,</span> <span class="n">prefix</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">SkipRNNCell</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">prefix</span><span class="o">=</span><span class="n">prefix</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="n">params</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cell</span> <span class="o">=</span> <span class="n">cell</span>
<span class="k">def</span> <span class="nf">hybrid_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">F</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">length</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">states</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">run_rnn</span><span class="p">():</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">cell</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">states</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">copy_states</span><span class="p">():</span>
<span class="k">return</span> <span class="n">F</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="n">states</span>
<span class="n">out</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">cond</span><span class="p">(</span><span class="n">i</span> <span class="o"><</span> <span class="n">length</span><span class="p">,</span> <span class="n">run_rnn</span><span class="p">,</span> <span class="n">copy_states</span><span class="p">)</span>
<span class="k">return</span> <span class="n">out</span><span class="p">,</span> <span class="n">state</span>
<span class="k">class</span> <span class="nc">RNNLayer</span><span class="p">(</span><span class="n">HybridBlock</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cell</span><span class="p">,</span> <span class="n">prefix</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">RNNLayer</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">prefix</span><span class="o">=</span><span class="n">prefix</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="n">params</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cell</span> <span class="o">=</span> <span class="n">SkipRNNCell</span><span class="p">(</span><span class="n">cell</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">hybrid_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">F</span><span class="p">,</span> <span class="n">length</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">init_states</span><span class="p">):</span>
        <span class="c1"># body is called once per time step; states holds [step counter, RNN states]</span>
        <span class="k">def</span> <span class="nf">body</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">states</span><span class="p">):</span>
            <span class="n">i</span> <span class="o">=</span> <span class="n">states</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
            <span class="c1"># self.cell skips the computation (zero output) once i reaches length</span>
            <span class="n">out</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cell</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">length</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">states</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
            <span class="k">return</span> <span class="n">out</span><span class="p">,</span> <span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">states</span><span class="p">]</span>
        <span class="c1"># the step counter starts at 0</span>
        <span class="n">out</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">foreach</span><span class="p">(</span><span class="n">body</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="p">[</span><span class="n">F</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,)),</span> <span class="n">init_states</span><span class="p">])</span>
        <span class="k">return</span> <span class="n">out</span><span class="p">,</span> <span class="n">state</span>

<span class="n">seq_len</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">input_size</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">rnn_data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">input_size</span><span class="p">))</span>
<span class="c1"># LSTMCell keeps two states (h and c), each of shape (batch_size, hidden_size)</span>
<span class="n">init_states</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">)]</span>
<span class="n">cell</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">gluon</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">RNNLayer</span><span class="p">(</span><span class="n">cell</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">initialize</span><span class="p">()</span>
<span class="c1"># valid length 3: only the first 3 of the 5 time steps are computed</span>
<span class="n">out</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">3</span><span class="p">]),</span> <span class="n">rnn_data</span><span class="p">,</span> <span class="n">init_states</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">rnn_data</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="p">[[[</span><span class="o">-</span><span class="mf">1.25296438</span> <span class="mf">0.387312</span> <span class="o">-</span><span class="mf">0.41055229</span><span class="p">]]</span>
<span class="p">[[</span> <span class="mf">1.28453672</span> <span class="mf">0.21001032</span> <span class="o">-</span><span class="mf">0.08666432</span><span class="p">]]</span>
<span class="p">[[</span> <span class="mf">1.46422136</span> <span class="o">-</span><span class="mf">1.30581355</span> <span class="mf">0.9344402</span> <span class="p">]]</span>
<span class="p">[[</span> <span class="mf">0.5380863</span> <span class="o">-</span><span class="mf">0.16038011</span> <span class="mf">0.84187603</span><span class="p">]]</span>
<span class="p">[[</span><span class="o">-</span><span class="mf">1.00553632</span> <span class="mf">3.13221502</span> <span class="o">-</span><span class="mf">0.4358989</span> <span class="p">]]]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">5</span><span class="n">x1x3</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span>
<span class="p">[[[</span><span class="o">-</span><span class="mf">0.02620504</span> <span class="mf">0.1605694</span> <span class="mf">0.29636264</span><span class="p">]]</span>
<span class="p">[[</span><span class="o">-</span><span class="mf">0.00474182</span> <span class="mf">0.08719197</span> <span class="mf">0.17757624</span><span class="p">]]</span>
<span class="p">[[</span> <span class="mf">0.00631597</span> <span class="mf">0.04674901</span> <span class="mf">0.12468992</span><span class="p">]]</span>
<span class="p">[[</span> <span class="mf">0.</span> <span class="mf">0.</span> <span class="mf">0.</span> <span class="p">]]</span>
<span class="p">[[</span> <span class="mf">0.</span> <span class="mf">0.</span> <span class="mf">0.</span> <span class="p">]]]</span>
<span class="o"><</span><span class="n">NDArray</span> <span class="mi">5</span><span class="n">x1x3</span> <span class="nd">@cpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">></span>
</pre></div>
</div>
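<p>As a cross-check on the output above, the following sketch re-creates the same control flow in plain NumPy, without MXNet. It is an illustration under stated assumptions, not the tutorial's implementation: <code class="docutils literal"><span class="pre">toy_cell</span></code> is a hypothetical stand-in for <code class="docutils literal"><span class="pre">LSTMCell</span></code> (a single tanh recurrence that, like the example, assumes <code class="docutils literal"><span class="pre">input_size</span></code> equals <code class="docutils literal"><span class="pre">hidden_size</span></code>), and <code class="docutils literal"><span class="pre">masked_scan</span></code> is a hypothetical helper, not an MXNet API. The loop plays the role of <code class="docutils literal"><span class="pre">foreach</span></code> and the <code class="docutils literal"><span class="pre">if</span></code> plays the role of <code class="docutils literal"><span class="pre">cond</span></code>, assuming, as in the cond-wrapped cell of this example, that the false branch returns zeros together with the unchanged states. With a valid length of 3, steps 3 and 4 take the false branch, which is why the last two rows of <code class="docutils literal"><span class="pre">out</span></code> above are zero.</p>
<div class="highlight-default"><div class="highlight"><pre><span></span>import numpy as np


def toy_cell(x, h):
    # Hypothetical stand-in for LSTMCell: one tanh recurrence step.
    # Returns (output, new_state); here both are the new hidden state.
    new_h = np.tanh(x + h)
    return new_h, new_h


def masked_scan(cell, data, init_state, valid_len):
    """Hypothetical reference for foreach + cond over a padded sequence."""
    outputs = []
    state = init_state
    for t, x in enumerate(data):  # foreach: iterate over axis 0 (time)
        if t >= valid_len:
            # cond false branch: zero output, state carried through unchanged
            out = np.zeros_like(x)
        else:
            # cond true branch: an ordinary cell step
            out, state = cell(x, state)
        outputs.append(out)
    return np.stack(outputs), state


rng = np.random.default_rng(0)
data = rng.normal(size=(5, 1, 3))  # (seq_len, batch_size, input_size)
init_state = rng.normal(size=(1, 3))  # (batch_size, hidden_size)
out, final_state = masked_scan(toy_cell, data, init_state, valid_len=3)
print(out)
</pre></div>
</div>
<p>Because the false branch returns its input state unchanged, <code class="docutils literal"><span class="pre">final_state</span></code> equals the state after the last valid step (here the same values as <code class="docutils literal"><span class="pre">out[2]</span></code>, since <code class="docutils literal"><span class="pre">toy_cell</span></code> returns its new state as its output), so padding steps after the valid length do not alter the final state.</p>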
<div class="btn-group" role="group">
<div class="download-btn"><a download="ControlFlowTutorial.ipynb" href="ControlFlowTutorial.ipynb"><span class="glyphicon glyphicon-download-alt"></span> ControlFlowTutorial.ipynb</a></div></div></div>
</div>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Hybridize Gluon models with control flows.</a><ul>
<li><a class="reference internal" href="#prepare-running-the-code">Prepare running the code</a></li>
<li><a class="reference internal" href="#foreach">foreach</a><ul>
<li><a class="reference internal" href="#example-1-foreach-works-like-map">Example 1: <code class="docutils literal"><span class="pre">foreach</span></code> works like map</a></li>
<li><a class="reference internal" href="#example-2-foreach-works-like-scan">Example 2: <code class="docutils literal"><span class="pre">foreach</span></code> works like scan</a></li>
<li><a class="reference internal" href="#example-3-foreach-with-both-outputs-and-states">Example 3: <code class="docutils literal"><span class="pre">foreach</span></code> with both outputs and states</a></li>
<li><a class="reference internal" href="#example-4-use-foreach-to-run-an-rnn-on-a-variable-length-sequence">Example 4: use <code class="docutils literal"><span class="pre">foreach</span></code> to run an RNN on a variable-length sequence</a></li>
</ul>
</li>
<li><a class="reference internal" href="#while-loop">while_loop</a><ul>
<li><a class="reference internal" href="#example-5-scan-with-while-loop">Example 5: scan with while_loop</a></li>
</ul>
</li>
<li><a class="reference internal" href="#cond">cond</a><ul>
<li><a class="reference internal" href="#example-6-skip-rnn-computation-with-cond">Example 6: skip RNN computation with cond</a></li>
</ul>
</li>
</ul>
</li>
</ul>
</div>
</div>
</div><div class="footer">
<div class="section-disclaimer">
<div class="container">
<div>
<img height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/>
<p>
Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p>
<p>
"Copyright © 2017-2018, The Apache Software Foundation
Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation."
</p>
</div>
</div>
</div>
</div> <!-- pagename != index -->
</div>
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script src="../../_static/js/page.js" type="text/javascript"></script>
<script src="../../_static/js/docversion.js" type="text/javascript"></script>
<script type="text/javascript">
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</body>
</html>