blob: 9deb569abdc707ce7c41b1787356722002cb8b47 [file] [log] [blame]
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta content="IE=edge" http-equiv="X-UA-Compatible"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<meta content="Generative Adversarial Network (GAN)" property="og:title">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image">
<meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image:secure_url">
<meta content="Generative Adversarial Network (GAN)" property="og:description"/>
<title>Generative Adversarial Network (GAN) — mxnet documentation</title>
<link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/>
<link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/>
<link href="../../_static/basic.css" rel="stylesheet" type="text/css">
<link href="../../_static/pygments.css" rel="stylesheet" type="text/css">
<link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript">
// Global page configuration object.
// NOTE(review): presumably read by the Sphinx helper scripts loaded later in
// <head> (doctools.js, searchtools_custom.js); confirm against those files.
var DOCUMENTATION_OPTIONS = {
// Relative path prefix; the same '../../' prefix is used by the _static asset links in this page.
URL_ROOT: '../../',
// Left empty on this page.
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: '.txt'
};
</script>
<script src="https://code.jquery.com/jquery-1.11.1.min.js" type="text/javascript"></script>
<script src="../../_static/underscore.js" type="text/javascript"></script>
<script src="../../_static/searchtools_custom.js" type="text/javascript"></script>
<script src="../../_static/doctools.js" type="text/javascript"></script>
<script src="../../_static/selectlang.js" type="text/javascript"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
<!-- Once the DOM is ready, load the search index for version 1.3.1 and initialise the search UI. NOTE(review): the Search object is presumably defined by searchtools_custom.js loaded above; confirm. -->
<script type="text/javascript"> jQuery(function() { Search.loadIndex("/versions/1.3.1/searchindex.js"); Search.init();}); </script>
<script>
// Google Analytics (analytics.js) asynchronous loader.
// If window.ga is not already defined, defines it as a stub that pushes its
// arguments onto ga.q, stores the current timestamp in ga.l, then creates an
// async <script> element pointing at analytics.js and inserts it before the
// first <script> element in the document.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new
Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
// Queue two commands: create a tracker for property UA-96378503-1, then send a pageview hit.
ga('create', 'UA-96378503-1', 'auto');
ga('send', 'pageview');
</script>
<!-- -->
<!-- <script type="text/javascript" src="../../_static/jquery.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/underscore.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="../../_static/doctools.js"></script> -->
<!-- -->
<!-- <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> -->
<!-- -->
<link href="../../genindex.html" rel="index" title="Index">
<link href="../../search.html" rel="search" title="Search"/>
<link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/>
</head>
<body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document">
<div class="content-block"><div class="navbar navbar-fixed-top">
<div class="container" id="navContainer">
<div class="innder" id="header-inner">
<h1 id="logo-wrap">
<a href="../../" id="logo"><img alt="MXNet" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a>
</h1>
<nav class="nav-bar" id="main-nav">
<a class="main-nav-link" href="/versions/1.3.1/install/index.html">Install</a>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.3.1/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="https://www.d2l.ai/">Dive into Deep Learning</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.3.1/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/clojure/index.html">Clojure</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/scala/index.html">Scala</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-docs">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs">
<li><a class="main-nav-link" href="/versions/1.3.1/faq/index.html">FAQ</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/tutorials/index.html">Tutorials</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.3.1/example">Examples</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/architecture/index.html">Architecture</a></li>
<li><a class="main-nav-link" href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/model_zoo/index.html">Model Zoo</a></li>
<li><a class="main-nav-link" href="https://github.com/onnx/onnx-mxnet">ONNX</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-community">
<a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span class="caret"></span></a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community">
<li><a class="main-nav-link" href="http://discuss.mxnet.io">Forum</a></li>
<li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/1.3.1">Github</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/community/contribute.html">Contribute</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/community/ecosystem.html">Ecosystem</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/community/powered_by.html">Powered By</a></li>
</ul>
</span>
<span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">1.3.1<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a href="/">master</a></li><li><a href="/versions/1.7.0/">1.7.0</a></li><li><a href=/versions/1.6.0/>1.6.0</a></li><li><a href=/versions/1.5.0/>1.5.0</a></li><li><a href=/versions/1.4.1/>1.4.1</a></li><li><a href=/versions/1.3.1/>1.3.1</a></li><li><a href=/versions/1.2.1/>1.2.1</a></li><li><a href=/versions/1.1.0/>1.1.0</a></li><li><a href=/versions/1.0.0/>1.0.0</a></li><li><a href=/versions/0.12.1/>0.12.1</a></li><li><a href=/versions/0.11.0/>0.11.0</a></li></ul></span></nav>
<!-- Global helper returning the relative path prefix "../../" (the same value as DOCUMENTATION_OPTIONS.URL_ROOT in the head). Its callers are not visible in this page. -->
<script> function getRootPath(){ return "../../" } </script>
<div class="burgerIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"></a>
<ul class="dropdown-menu" id="burgerMenu">
<li><a href="/versions/1.3.1/install/index.html">Install</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/tutorials/index.html">Tutorials</a></li>
<li class="dropdown-submenu dropdown">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Gluon</a>
<ul class="dropdown-menu navbar-menu" id="package-dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.3.1/tutorials/gluon/gluon.html">About</a></li>
<li><a class="main-nav-link" href="http://gluon.mxnet.io">The Straight Dope (Tutorials)</a></li>
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li>
<li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">API</a>
<ul class="dropdown-menu">
<li><a class="main-nav-link" href="/versions/1.3.1/api/python/index.html">Python</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/c++/index.html">C++</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/clojure/index.html">Clojure</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/julia/index.html">Julia</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/perl/index.html">Perl</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/r/index.html">R</a></li>
<li><a class="main-nav-link" href="/versions/1.3.1/api/scala/index.html">Scala</a></li>
</ul>
</li>
<li class="dropdown-submenu">
<a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Docs</a>
<ul class="dropdown-menu">
<li><a href="/versions/1.3.1/faq/index.html" tabindex="-1">FAQ</a></li>
<li><a href="/versions/1.3.1/tutorials/index.html" tabindex="-1">Tutorials</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/1.3.1/example" tabindex="-1">Examples</a></li>
<li><a href="/versions/1.3.1/architecture/index.html" tabindex="-1">Architecture</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home" tabindex="-1">Developer Wiki</a></li>
<li><a href="/versions/1.3.1/model_zoo/index.html" tabindex="-1">Gluon Model Zoo</a></li>
<li><a href="https://github.com/onnx/onnx-mxnet" tabindex="-1">ONNX</a></li>
</ul>
</li>
<li class="dropdown-submenu dropdown">
<a aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" role="button" tabindex="-1">Community</a>
<ul class="dropdown-menu">
<li><a href="http://discuss.mxnet.io" tabindex="-1">Forum</a></li>
<li><a href="https://github.com/apache/incubator-mxnet/tree/1.3.1" tabindex="-1">Github</a></li>
<li><a href="/versions/1.3.1/community/contribute.html" tabindex="-1">Contribute</a></li>
<li><a href="/versions/1.3.1/community/ecosystem.html" tabindex="-1">Ecosystem</a></li>
<li><a href="/versions/1.3.1/community/powered_by.html" tabindex="-1">Powered By</a></li>
</ul>
</li>
<li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">1.3.1</a><ul class="dropdown-menu"><li><a tabindex="-1" href="/">master</a></li><li><a tabindex="-1" href="/versions/1.7.0/">1.7.0</a></li><li><a tabindex="-1" href="/versions/1.6.0/">1.6.0</a></li><li><a tabindex="-1" href="/versions/1.5.0/">1.5.0</a></li><li><a tabindex="-1" href="/versions/1.4.1/">1.4.1</a></li><li><a tabindex="-1" href="/versions/1.3.1/">1.3.1</a></li><li><a tabindex="-1" href="/versions/1.2.1/">1.2.1</a></li><li><a tabindex="-1" href="/versions/1.1.0/">1.1.0</a></li><li><a tabindex="-1" href="/versions/1.0.0/">1.0.0</a></li><li><a tabindex="-1" href="/versions/0.12.1/">0.12.1</a></li><li><a tabindex="-1" href="/versions/0.11.0/">0.11.0</a></li></ul></li></ul>
</div>
<div class="plusIcon dropdown">
<a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a>
<ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul>
</div>
<div id="search-input-wrap">
<form action="../../search.html" autocomplete="off" class="" method="get" role="search">
<div class="form-group inner-addon left-addon">
<i class="glyphicon glyphicon-search"></i>
<input aria-label="Search" class="form-control" name="q" placeholder="Search" type="text"/>
</div>
<input name="check_keywords" type="hidden" value="yes"/>
<input name="area" type="hidden" value="default"/>
</form>
<div id="search-preview"></div>
</div>
<div id="searchIcon">
<span aria-hidden="true" class="glyphicon glyphicon-search"></span>
</div>
<!-- <div id="lang-select-wrap"> -->
<!-- <label id="lang-select-label"> -->
<!-- <\!-- <i class="fa fa-globe"></i> -\-> -->
<!-- <span></span> -->
<!-- </label> -->
<!-- <select id="lang-select"> -->
<!-- <option value="en">Eng</option> -->
<!-- <option value="zh">中文</option> -->
<!-- </select> -->
<!-- </div> -->
<!-- <a id="mobile-nav-toggle">
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
<span class="mobile-nav-toggle-bar"></span>
</a> -->
</div>
</div>
</div>
<script type="text/javascript">
// Set the page background to plain white via an inline style, overriding the image given in the <body background="..."> attribute.
$('body').css('background', 'white');
</script>
<div class="container">
<div class="row">
<div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../api/python/index.html">Python Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/r/index.html">R Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/julia/index.html">Julia Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/c++/index.html">C++ Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/scala/index.html">Scala Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../api/perl/index.html">Perl Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../faq/index.html">HowTo Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li>
<li class="toctree-l1"><a class="reference internal" href="../index.html">Tutorials</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../community/contribute.html">Community</a></li>
</ul>
</div>
</div>
<div class="content">
<div class="page-tracker"></div>
<div class="section" id="generative-adversarial-network-gan">
<span id="generative-adversarial-network-gan"></span><h1>Generative Adversarial Network (GAN)<a class="headerlink" href="#generative-adversarial-network-gan" title="Permalink to this headline"></a></h1>
<p>Generative Adversarial Networks (GANs) are a class of algorithms used in unsupervised learning - you don’t need labels for your dataset in order to train a GAN.</p>
<p>The GAN framework is composed of two neural networks: a Generator network and a Discriminator network.</p>
<p>The Generator’s job is to take a set of random numbers and produce the data (such as images or text).</p>
<p>The Discriminator then takes in that data as well as samples of that data from a dataset and tries to determine if it is “fake” (created by the Generator network) or “real” (from the original dataset).</p>
<p>During training, the two networks play a game against each other. The Generator tries to create realistic data, so that it can fool the Discriminator into thinking that the data it generated is from the original dataset. At the same time, the Discriminator tries to not be fooled - it learns to become better at determining if data is real or fake.</p>
<p>Since the two networks are fighting in this game, they can be seen as adversaries, which is where the term “Generative Adversarial Network” comes from.</p>
<div class="section" id="deep-convolutional-generative-adversarial-networks">
<span id="deep-convolutional-generative-adversarial-networks"></span><h2>Deep Convolutional Generative Adversarial Networks<a class="headerlink" href="#deep-convolutional-generative-adversarial-networks" title="Permalink to this headline"></a></h2>
<p>This tutorial takes a look at Deep Convolutional Generative Adversarial Networks (DCGAN), which combines Convolutional Neural Networks (CNNs) and GANs.</p>
<p>We will create a DCGAN that is able to create images of handwritten digits from random numbers. The tutorial uses the neural net architecture and guidelines outlined in <a class="reference external" href="https://arxiv.org/abs/1511.06434">this paper</a>, and the MNIST dataset.</p>
</div>
<div class="section" id="how-to-use-this-tutorial">
<span id="how-to-use-this-tutorial"></span><h2>How to Use This Tutorial<a class="headerlink" href="#how-to-use-this-tutorial" title="Permalink to this headline"></a></h2>
<p>You can use this tutorial by executing each snippet of python code in order as it appears in the tutorial.</p>
<ol class="simple">
<li>The first net is the “Generator” and creates images of handwritten digits from random numbers.</li>
<li>The second net is the “Discriminator” and determines if the image created by the Generator is real (a realistic looking image of handwritten digits) or fake (an image that does not look like it is from the original dataset).</li>
</ol>
<p>Apart from creating a DCGAN, you’ll also learn:</p>
<ul class="simple">
<li>How to manipulate and iterate through batches of image data that you can feed into your neural network.</li>
<li>How to create a custom MXNet data iterator that generates random numbers from a normal distribution.</li>
<li>How to create a custom training process in MXNet, using lower level functions from the MXNet Module API such as .bind(), .forward(), and .backward(). The training process for a DCGAN is more complex than many other neural networks, so we need to use these functions instead of using the higher level .fit() function.</li>
<li>How to visualize images as they are going through the training process</li>
</ul>
</div>
<div class="section" id="prerequisites">
<span id="prerequisites"></span><h2>Prerequisites<a class="headerlink" href="#prerequisites" title="Permalink to this headline"></a></h2>
<p>This tutorial assumes you are familiar with the concepts of CNNs and have implemented one in MXNet. You should also be familiar with the concept of logistic regression. Having a basic understanding of MXNet data iterators helps, since we will create a custom data iterator to iterate through random numbers as inputs to the Generator network.</p>
<p>This example is designed to be trained on a single GPU. Training this network on CPU can be slow, so it’s recommended that you use a GPU for training.</p>
<p>To complete this tutorial, you need:</p>
<ul class="simple">
<li>MXNet</li>
<li>Python, and the following libraries for Python:<ul>
<li>Numpy - for matrix math</li>
<li>OpenCV - for image manipulation</li>
<li>Matplotlib - to visualize the output</li>
</ul>
</li>
</ul>
</div>
<div class="section" id="the-data">
<span id="the-data"></span><h2>The Data<a class="headerlink" href="#the-data" title="Permalink to this headline"></a></h2>
<p>We need two pieces of data to train the DCGAN:</p>
<ol class="simple">
<li>Images of handwritten digits from the MNIST dataset</li>
<li>Random numbers from a normal distribution</li>
</ol>
<p>The Generator network will use the random numbers as the input to produce the images of handwritten digits, and the Discriminator network will use images of handwritten digits from the MNIST dataset to determine if images produced by the Generator are realistic.</p>
<p>The MNIST dataset contains 70,000 images of handwritten digits. Each image is 28x28 pixels in size. To create random numbers, we’re going to create a custom MXNet data iterator, which will return random numbers from a normal distribution as we need them.</p>
</div>
<div class="section" id="prepare-the-data">
<span id="prepare-the-data"></span><h2>Prepare the Data<a class="headerlink" href="#prepare-the-data" title="Permalink to this headline"></a></h2>
<div class="section" id="preparing-the-mnsit-dataset">
<span id="preparing-the-mnsit-dataset"></span><h3>1. Preparing the MNIST dataset<a class="headerlink" href="#preparing-the-mnsit-dataset" title="Permalink to this headline"></a></h3>
<p>Let us start by preparing the handwritten digits from the MNIST dataset.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="n">mnist_train</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">gluon</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">vision</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">train</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">mnist_test</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">gluon</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">vision</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">train</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># The downloaded data is of type `Dataset` which are</span>
<span class="c1"># Well suited to work with the new Gluon interface but less</span>
<span class="c1"># With the older symbol API, used in this tutorial. </span>
<span class="c1"># Therefore we convert them to numpy array first</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">70000</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">))</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">mnist_train</span><span class="p">):</span>
    <span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()[:,:,</span><span class="mi">0</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">mnist_test</span><span class="p">):</span>
    <span class="n">X</span><span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">mnist_train</span><span class="p">)</span><span class="o">+</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()[:,:,</span><span class="mi">0</span><span class="p">]</span>
</pre></div>
</div>
<p>Next, we will shuffle the handwritten digits by using numpy to apply a random permutation to the rows (images) of the dataset. Every image in the dataset is arranged into a 28x28 grid, where each cell in the grid represents 1 pixel of the image.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1">#Use a seed so that we get the same random permutation each time</span>
<span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">X</span><span class="p">[</span><span class="n">p</span><span class="p">]</span>
</pre></div>
</div>
<p>Since the DCGAN that we’re creating takes in a 64x64 image as the input, we will use OpenCV to resize each 28x28 image to 64x64:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">cv2</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">([</span><span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">(</span><span class="mi">64</span><span class="p">,</span><span class="mi">64</span><span class="p">))</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">X</span><span class="p">])</span>
</pre></div>
</div>
<p>Each pixel in the 64x64 image is represented by a number between 0 and 255 that represents the intensity of the pixel. However, we want to input numbers between -1 and 1 into the DCGAN, as suggested by the <a class="reference external" href="https://arxiv.org/abs/1511.06434">research paper</a>. To rescale the pixel values, we will divide them by (255/2). This changes the scale to 0-2. We then subtract 1 to get them in the range of -1 to 1.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">X</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">copy</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span><span class="o">/</span><span class="p">(</span><span class="mf">255.0</span><span class="o">/</span><span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="mf">1.0</span>
</pre></div>
</div>
<p>Ultimately, images are fed into the neural net through a 70000x3x64x64 array but they are currently in a 70000x64x64 array. We need to add 3 channels to the images. Typically, when we are working with the images, the 3 channels represent the red, green, and blue (RGB) components of each image. Since the MNIST dataset is grayscale, each image has only 1 channel. We will first reshape the array to add a single channel axis, and then use np.tile to copy that channel into all 3 channels:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">X</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">70000</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">))</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
</pre></div>
</div>
<p>Finally, we will put the images into MXNet’s NDArrayIter, which will allow MXNet to easily iterate through the images during training. We will also split them up into batches of 64 images each. Every time we iterate, we will get a 4 dimensional array with size (64, 3, 64, 64), representing a batch of 64 images.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">image_iter</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">NDArrayIter</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="preparing-random-numbers">
<span id="preparing-random-numbers"></span><h3>2. Preparing Random Numbers<a class="headerlink" href="#preparing-random-numbers" title="Permalink to this headline"></a></h3>
<p>We need to input random numbers from a normal distribution to the Generator network, so we will create an MXNet DataIter that produces random numbers for each training batch. The DataIter is the base class of MXNet’s Data Loading API. Below, we create a class called RandIter which is a subclass of DataIter. We use MXNet’s built-in mx.random.normal function to return the random numbers from a normal distribution during the iteration.</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">RandIter</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">DataIter</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">ndim</span><span class="p">):</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">ndim</span> <span class="o">=</span> <span class="n">ndim</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">provide_data</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">'rand'</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">ndim</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">provide_label</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">def</span> <span class="nf">iter_next</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">True</span>
    <span class="k">def</span> <span class="nf">getdata</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="c1">#Returns random numbers from a gaussian (normal) distribution</span>
        <span class="c1">#with mean=0 and standard deviation = 1</span>
        <span class="k">return</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ndim</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))]</span>
</pre></div>
</div>
<p>When we initialize the RandIter, we need to provide two numbers: the batch size and the number of random values from which a single image is produced. This number is referred to as Z, and we will set this to 100. This value comes from the research paper on the topic. Every time we iterate and get a batch of random numbers, we will get a 4-dimensional array with shape: (batch_size, Z, 1, 1), which in the example is (64, 100, 1, 1).</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">Z</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">rand_iter</span> <span class="o">=</span> <span class="n">RandIter</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">Z</span><span class="p">)</span>
</pre></div>
</div>
</div>
</div>
<div class="section" id="create-the-model">
<span id="create-the-model"></span><h2>Create the Model<a class="headerlink" href="#create-the-model" title="Permalink to this headline"></a></h2>
<p>The model has two networks that we will train together - the Generator network and the Discriminator network.</p>
<div class="section" id="the-generator">
<span id="the-generator"></span><h3>The Generator<a class="headerlink" href="#the-generator" title="Permalink to this headline"></a></h3>
<p>Let us start off by defining the Generator network, which uses Deconvolution layers (also called fractionally strided layers) to generate an image from random numbers:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">no_bias</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">fix_gamma</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">epsilon</span> <span class="o">=</span> <span class="mf">1e-5</span> <span class="o">+</span> <span class="mf">1e-12</span>
<span class="n">rand</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'rand'</span><span class="p">)</span>
<span class="n">g1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Deconvolution</span><span class="p">(</span><span class="n">rand</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'g1'</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="n">num_filter</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span> <span class="n">no_bias</span><span class="o">=</span><span class="n">no_bias</span><span class="p">)</span>
<span class="n">gbn1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span><span class="n">g1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'gbn1'</span><span class="p">,</span> <span class="n">fix_gamma</span><span class="o">=</span><span class="n">fix_gamma</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">epsilon</span><span class="p">)</span>
<span class="n">gact1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">gbn1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'gact1'</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">)</span>
<span class="n">g2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Deconvolution</span><span class="p">(</span><span class="n">gact1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'g2'</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="n">pad</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="n">num_filter</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">no_bias</span><span class="o">=</span><span class="n">no_bias</span><span class="p">)</span>
<span class="n">gbn2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span><span class="n">g2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'gbn2'</span><span class="p">,</span> <span class="n">fix_gamma</span><span class="o">=</span><span class="n">fix_gamma</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">epsilon</span><span class="p">)</span>
<span class="n">gact2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">gbn2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'gact2'</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">)</span>
<span class="n">g3</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Deconvolution</span><span class="p">(</span><span class="n">gact2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'g3'</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="n">pad</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="n">num_filter</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">no_bias</span><span class="o">=</span><span class="n">no_bias</span><span class="p">)</span>
<span class="n">gbn3</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span><span class="n">g3</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'gbn3'</span><span class="p">,</span> <span class="n">fix_gamma</span><span class="o">=</span><span class="n">fix_gamma</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">epsilon</span><span class="p">)</span>
<span class="n">gact3</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">gbn3</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'gact3'</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">)</span>
<span class="n">g4</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Deconvolution</span><span class="p">(</span><span class="n">gact3</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'g4'</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="n">pad</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="n">num_filter</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">no_bias</span><span class="o">=</span><span class="n">no_bias</span><span class="p">)</span>
<span class="n">gbn4</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span><span class="n">g4</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'gbn4'</span><span class="p">,</span> <span class="n">fix_gamma</span><span class="o">=</span><span class="n">fix_gamma</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">epsilon</span><span class="p">)</span>
<span class="n">gact4</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">gbn4</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'gact4'</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">)</span>
<span class="n">g5</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Deconvolution</span><span class="p">(</span><span class="n">gact4</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'g5'</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="n">pad</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="n">num_filter</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">no_bias</span><span class="o">=</span><span class="n">no_bias</span><span class="p">)</span>
<span class="n">generatorSymbol</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">g5</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'gact5'</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s1">'tanh'</span><span class="p">)</span>
</pre></div>
</div>
<p>The Generator image starts with random numbers that will be obtained from the RandIter we created earlier, so we created the rand variable for this input.
We then start creating the model starting with a Deconvolution layer (sometimes called ‘fractionally strided layer’). We apply batch normalization and ReLU activation after the Deconvolution layer.</p>
<p>We repeat this process 4 times, applying a (2,2) stride and (1,1) pad at each Deconvolution layer, which doubles the size of the image at each layer. By creating these layers, the Generator network will have to learn to upsample the input vector of random numbers, Z, at each layer, so that the network outputs a final image. We also reduce by half the number of filters at each layer, reducing dimensionality at each layer. Ultimately, the output layer is a 64x64x3 layer, representing the size and channels of the image. We use tanh activation instead of relu on the last layer, as recommended by the research on DCGANs. The outputs of the neurons in the final gact5 layer represent the pixels of the generated image.</p>
<p>Notice we used 3 parameters to help us create the model: no_bias, fix_gamma, and epsilon. Neurons in the network won’t have a bias added to them; this seems to work better in practice for the DCGAN. In the batch norm layer, we set fix_gamma=True, which means gamma=1 for all of the batch norm layers. epsilon is a small number that gets added to the batch norm so that we don’t end up dividing by zero. By default, CuDNN requires that this number is greater than 1e-5, so we add a small number to this value, ensuring this value stays small.</p>
</div>
<div class="section" id="the-discriminator">
<span id="the-discriminator"></span><h3>The Discriminator<a class="headerlink" href="#the-discriminator" title="Permalink to this headline"></a></h3>
<p>Let us now create the Discriminator network, which will take in images of handwritten digits from the MNIST dataset and images created by the Generator network:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'data'</span><span class="p">)</span>
<span class="n">d1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Convolution</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'d1'</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="n">pad</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="n">num_filter</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">no_bias</span><span class="o">=</span><span class="n">no_bias</span><span class="p">)</span>
<span class="n">dact1</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">d1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'dact1'</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s1">'leaky'</span><span class="p">,</span> <span class="n">slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">)</span>
<span class="n">d2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Convolution</span><span class="p">(</span><span class="n">dact1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'d2'</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="n">pad</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="n">num_filter</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">no_bias</span><span class="o">=</span><span class="n">no_bias</span><span class="p">)</span>
<span class="n">dbn2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span><span class="n">d2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'dbn2'</span><span class="p">,</span> <span class="n">fix_gamma</span><span class="o">=</span><span class="n">fix_gamma</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">epsilon</span><span class="p">)</span>
<span class="n">dact2</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">dbn2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'dact2'</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s1">'leaky'</span><span class="p">,</span> <span class="n">slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">)</span>
<span class="n">d3</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Convolution</span><span class="p">(</span><span class="n">dact2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'d3'</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="n">pad</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="n">num_filter</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">no_bias</span><span class="o">=</span><span class="n">no_bias</span><span class="p">)</span>
<span class="n">dbn3</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span><span class="n">d3</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'dbn3'</span><span class="p">,</span> <span class="n">fix_gamma</span><span class="o">=</span><span class="n">fix_gamma</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">epsilon</span><span class="p">)</span>
<span class="n">dact3</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">dbn3</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'dact3'</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s1">'leaky'</span><span class="p">,</span> <span class="n">slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">)</span>
<span class="n">d4</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Convolution</span><span class="p">(</span><span class="n">dact3</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'d4'</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="n">pad</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="n">num_filter</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span> <span class="n">no_bias</span><span class="o">=</span><span class="n">no_bias</span><span class="p">)</span>
<span class="n">dbn4</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span><span class="n">d4</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'dbn4'</span><span class="p">,</span> <span class="n">fix_gamma</span><span class="o">=</span><span class="n">fix_gamma</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">epsilon</span><span class="p">)</span>
<span class="n">dact4</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">dbn4</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'dact4'</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s1">'leaky'</span><span class="p">,</span> <span class="n">slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">)</span>
<span class="n">d5</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Convolution</span><span class="p">(</span><span class="n">dact4</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'d5'</span><span class="p">,</span> <span class="n">kernel</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="n">num_filter</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">no_bias</span><span class="o">=</span><span class="n">no_bias</span><span class="p">)</span>
<span class="n">d5</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Flatten</span><span class="p">(</span><span class="n">d5</span><span class="p">)</span>
<span class="n">label</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'label'</span><span class="p">)</span>
<span class="n">discriminatorSymbol</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">sym</span><span class="o">.</span><span class="n">LogisticRegressionOutput</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">d5</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">label</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'dloss'</span><span class="p">)</span>
</pre></div>
</div>
<p>We start off by creating the data variable, which is used to hold the input images to the Discriminator.</p>
<p>The Discriminator then goes through a series of 5 convolutional layers, each with a 4x4 kernel. The first four use a 2x2 stride and 1x1 pad, so they halve the size of the image (which starts at 64x64) at each convolutional layer; the fifth specifies no stride or pad. The model also increases dimensionality at each layer by doubling the number of filters per convolutional layer, starting at 128 filters and ending at 1024 filters before we flatten the output.</p>
<p>At the final convolution, we flatten the neural net to get one number as the final output of Discriminator network. This number is the probability that the image is real, as determined by the Discriminator. We use logistic regression to determine this probability. When we pass in “real” images from the MNIST dataset, we can label these as 1 and we can label the “fake” images from the Generator net as 0 to perform logistic regression on the Discriminator network.</p>
</div>
<div class="section" id="prepare-the-models-using-the-module-api">
<span id="prepare-the-models-using-the-module-api"></span><h3>Prepare the models using the Module API<a class="headerlink" href="#prepare-the-models-using-the-module-api" title="Permalink to this headline"></a></h3>
<p>So far we have defined a MXNet Symbol for both the Generator and the Discriminator network. Before we can train the model, we need to bind these symbols using the Module API, which creates the computation graph for the models. It also allows us to decide how we want to initialize the model and what type of optimizer we want to use. Let us set up the Module for both the networks:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1">#Hyper-parameters</span>
<span class="n">sigma</span> <span class="o">=</span> <span class="mf">0.02</span>
<span class="n">lr</span> <span class="o">=</span> <span class="mf">0.0002</span>
<span class="n">beta1</span> <span class="o">=</span> <span class="mf">0.5</span>
<span class="c1"># Define the compute context, use GPU if available</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">gpu</span><span class="p">()</span> <span class="k">if</span> <span class="n">mx</span><span class="o">.</span><span class="n">test_utils</span><span class="o">.</span><span class="n">list_gpus</span><span class="p">()</span> <span class="k">else</span> <span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="c1">#=============Generator Module=============</span>
<span class="n">generator</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">generatorSymbol</span><span class="p">,</span> <span class="n">data_names</span><span class="o">=</span><span class="p">(</span><span class="s1">'rand'</span><span class="p">,),</span> <span class="n">label_names</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">context</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="n">generator</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">data_shapes</span><span class="o">=</span><span class="n">rand_iter</span><span class="o">.</span><span class="n">provide_data</span><span class="p">)</span>
<span class="n">generator</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">initializer</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="n">sigma</span><span class="p">))</span>
<span class="n">generator</span><span class="o">.</span><span class="n">init_optimizer</span><span class="p">(</span>
<span class="n">optimizer</span><span class="o">=</span><span class="s1">'adam'</span><span class="p">,</span>
<span class="n">optimizer_params</span><span class="o">=</span><span class="p">{</span>
<span class="s1">'learning_rate'</span><span class="p">:</span> <span class="n">lr</span><span class="p">,</span>
<span class="s1">'beta1'</span><span class="p">:</span> <span class="n">beta1</span><span class="p">,</span>
<span class="p">})</span>
<span class="n">mods</span> <span class="o">=</span> <span class="p">[</span><span class="n">generator</span><span class="p">]</span>
<span class="c1"># =============Discriminator Module=============</span>
<span class="n">discriminator</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">mod</span><span class="o">.</span><span class="n">Module</span><span class="p">(</span><span class="n">symbol</span><span class="o">=</span><span class="n">discriminatorSymbol</span><span class="p">,</span> <span class="n">data_names</span><span class="o">=</span><span class="p">(</span><span class="s1">'data'</span><span class="p">,),</span> <span class="n">label_names</span><span class="o">=</span><span class="p">(</span><span class="s1">'label'</span><span class="p">,),</span> <span class="n">context</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">data_shapes</span><span class="o">=</span><span class="n">image_iter</span><span class="o">.</span><span class="n">provide_data</span><span class="p">,</span>
<span class="n">label_shapes</span><span class="o">=</span><span class="p">[(</span><span class="s1">'label'</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,))],</span>
<span class="n">inputs_need_grad</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">init_params</span><span class="p">(</span><span class="n">initializer</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="n">sigma</span><span class="p">))</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">init_optimizer</span><span class="p">(</span>
<span class="n">optimizer</span><span class="o">=</span><span class="s1">'adam'</span><span class="p">,</span>
<span class="n">optimizer_params</span><span class="o">=</span><span class="p">{</span>
<span class="s1">'learning_rate'</span><span class="p">:</span> <span class="n">lr</span><span class="p">,</span>
<span class="s1">'beta1'</span><span class="p">:</span> <span class="n">beta1</span><span class="p">,</span>
<span class="p">})</span>
<span class="n">mods</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">discriminator</span><span class="p">)</span>
</pre></div>
</div>
<p>First, we create Modules for the networks and then bind the symbols that we’ve created in the previous steps to the modules.
We use rand_iter.provide_data as the data_shape to bind the Generator network. This means that as we iterate through batches of the data on the Generator Module, the RandIter will provide us with random numbers to feed the Module using its provide_data function.</p>
<p>Similarly, we bind the Discriminator Module to image_iter.provide_data, which gives us images from MNIST from the NDArrayIter we had set up earlier, called image_iter.</p>
<p>Notice that we are using the Normal Initialization, with the hyperparameter sigma=0.02. This means the weight initializations for the neurons in the networks will be random numbers from a Gaussian (normal) distribution with a mean of 0 and a standard deviation of 0.02.</p>
<p>We also use the Adam optimizer for gradient descent. We’ve set up two hyperparameters, lr and beta1, based on the values used in the DCGAN paper. We’re using a single GPU, gpu(0), for training. Set the context to cpu() if you do not have a GPU on your machine.</p>
</div>
<div class="section" id="visualizing-the-training">
<span id="visualizing-the-training"></span><h3>Visualizing The Training<a class="headerlink" href="#visualizing-the-training" title="Permalink to this headline"></a></h3>
<p>Before we train the model, let us set up some helper functions that will help visualize what the Generator is producing, compared to what the real image is:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="c1">#Takes the images in the batch and arranges them in an array so that they can be</span>
<span class="c1">#Plotted using matplotlib</span>
<span class="k">def</span> <span class="nf">fill_buf</span><span class="p">(</span><span class="n">buf</span><span class="p">,</span> <span class="n">num_images</span><span class="p">,</span> <span class="n">img</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
<span class="n">width</span> <span class="o">=</span> <span class="n">buf</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">/</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">height</span> <span class="o">=</span> <span class="n">buf</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">/</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">img_width</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">num_images</span><span class="o">%</span><span class="n">width</span><span class="p">)</span><span class="o">*</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">img_hight</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">num_images</span><span class="o">/</span><span class="n">height</span><span class="p">)</span><span class="o">*</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">buf</span><span class="p">[</span><span class="n">img_hight</span><span class="p">:</span><span class="n">img_hight</span><span class="o">+</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">img_width</span><span class="p">:</span><span class="n">img_width</span><span class="o">+</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">img</span>
<span class="c1">#Plots two images side by side using matplotlib</span>
<span class="k">def</span> <span class="nf">visualize</span><span class="p">(</span><span class="n">fake</span><span class="p">,</span> <span class="n">real</span><span class="p">):</span>
<span class="c1">#64x3x64x64 to 64x64x64x3</span>
<span class="n">fake</span> <span class="o">=</span> <span class="n">fake</span><span class="o">.</span><span class="n">transpose</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="c1">#Pixel values from 0-255</span>
<span class="n">fake</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">((</span><span class="n">fake</span><span class="o">+</span><span class="mf">1.0</span><span class="p">)</span><span class="o">*</span><span class="p">(</span><span class="mf">255.0</span><span class="o">/</span><span class="mf">2.0</span><span class="p">),</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
<span class="c1">#Repeat for real image</span>
<span class="n">real</span> <span class="o">=</span> <span class="n">real</span><span class="o">.</span><span class="n">transpose</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">real</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">((</span><span class="n">real</span><span class="o">+</span><span class="mf">1.0</span><span class="p">)</span><span class="o">*</span><span class="p">(</span><span class="mf">255.0</span><span class="o">/</span><span class="mf">2.0</span><span class="p">),</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
<span class="c1">#Create buffer array that will hold all the images in the batch</span>
<span class="c1">#Fill the buffer so to arrange all images in the batch onto the buffer array</span>
<span class="n">n</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">fake</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
<span class="n">fbuff</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">int</span><span class="p">(</span><span class="n">n</span><span class="o">*</span><span class="n">fake</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="nb">int</span><span class="p">(</span><span class="n">n</span><span class="o">*</span><span class="n">fake</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="nb">int</span><span class="p">(</span><span class="n">fake</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">])),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">img</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">fake</span><span class="p">):</span>
<span class="n">fill_buf</span><span class="p">(</span><span class="n">fbuff</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">img</span><span class="p">,</span> <span class="n">fake</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:</span><span class="mi">3</span><span class="p">])</span>
<span class="n">rbuff</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">int</span><span class="p">(</span><span class="n">n</span><span class="o">*</span><span class="n">real</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="nb">int</span><span class="p">(</span><span class="n">n</span><span class="o">*</span><span class="n">real</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="nb">int</span><span class="p">(</span><span class="n">real</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">])),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">img</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">real</span><span class="p">):</span>
<span class="n">fill_buf</span><span class="p">(</span><span class="n">rbuff</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">img</span><span class="p">,</span> <span class="n">real</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:</span><span class="mi">3</span><span class="p">])</span>
<span class="c1">#Create a matplotlib figure with two subplots: one for the real and the other for the fake</span>
<span class="c1">#fill each plot with the buffer array, which creates the image</span>
<span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">()</span>
<span class="n">ax1</span> <span class="o">=</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">1</span><span class="p">)</span>
<span class="n">ax1</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">fbuff</span><span class="p">)</span>
<span class="n">ax2</span> <span class="o">=</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">)</span>
<span class="n">ax2</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">rbuff</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</pre></div>
</div>
</div>
</div>
<div class="section" id="fit-the-model">
<span id="fit-the-model"></span><h2>Fit the Model<a class="headerlink" href="#fit-the-model" title="Permalink to this headline"></a></h2>
<p>Training the DCGAN is a complex process that requires multiple steps.
To fit the model, for every batch of data in the MNIST dataset:</p>
<ol class="simple">
<li>Use the Z vector, which contains the random numbers, to do a forward pass through the Generator network. This outputs the “fake” image, since it is created from the Generator.</li>
<li>Use the fake image as the input to do a forward and backward pass through the Discriminator network. We set the labels for logistic regression to 0 to represent that this is a fake image. This trains the Discriminator to learn what a fake image looks like. We save the gradient produced in backpropagation for the next step.</li>
<li>Do a forward and backward pass through the Discriminator using a real image from the MNIST dataset. The label for logistic regression will now be 1 to represent the real images, so the Discriminator can learn to recognize a real image.</li>
<li>Update the Discriminator using the sum of the gradient generated during backpropagation on the fake image and the gradient from backpropagation on the real image.</li>
<li>Now that the Discriminator has been updated for this data batch, we still need to update the Generator. First, do a forward and backward pass with the same data batch on the updated Discriminator, to produce a new gradient. Use the new gradient to do a backward pass on the Generator, and then update the Generator.</li>
</ol>
<p>Here is the main training loop for the DCGAN:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># =============train===============</span>
<span class="k">print</span><span class="p">(</span><span class="s1">'Training...'</span><span class="p">)</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">):</span>
<span class="n">image_iter</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">batch</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">image_iter</span><span class="p">):</span>
<span class="c1">#Get a batch of random numbers to generate an image from the generator</span>
<span class="n">rbatch</span> <span class="o">=</span> <span class="n">rand_iter</span><span class="o">.</span><span class="n">next</span><span class="p">()</span>
<span class="c1">#Forward pass on training batch</span>
<span class="n">generator</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">rbatch</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="c1">#Output of training batch is the 64x64x3 image</span>
<span class="n">outG</span> <span class="o">=</span> <span class="n">generator</span><span class="o">.</span><span class="n">get_outputs</span><span class="p">()</span>
<span class="c1">#Pass the generated (fake) image through the discriminator, and save the gradient</span>
<span class="c1">#Label (for logistic regression) is an array of 0's since this image is fake</span>
<span class="n">label</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,),</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="c1">#Forward pass on the output of the discriminator network</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">DataBatch</span><span class="p">(</span><span class="n">outG</span><span class="p">,</span> <span class="p">[</span><span class="n">label</span><span class="p">]),</span> <span class="n">is_train</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="c1">#Do the backward pass and save the gradient</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">gradD</span> <span class="o">=</span> <span class="p">[[</span><span class="n">grad</span><span class="o">.</span><span class="n">copyto</span><span class="p">(</span><span class="n">grad</span><span class="o">.</span><span class="n">context</span><span class="p">)</span> <span class="k">for</span> <span class="n">grad</span> <span class="ow">in</span> <span class="n">grads</span><span class="p">]</span> <span class="k">for</span> <span class="n">grads</span> <span class="ow">in</span> <span class="n">discriminator</span><span class="o">.</span><span class="n">_exec_group</span><span class="o">.</span><span class="n">grad_arrays</span><span class="p">]</span>
<span class="c1">#Pass a batch of real images from MNIST through the discriminator</span>
<span class="c1">#Set the label to be an array of 1's because these are the real images</span>
<span class="n">label</span><span class="p">[:]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">batch</span><span class="o">.</span><span class="n">label</span> <span class="o">=</span> <span class="p">[</span><span class="n">label</span><span class="p">]</span>
<span class="c1">#Forward pass on a batch of MNIST images</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="c1">#Do the backward pass and add the saved gradient from the fake images to the gradient</span>
<span class="c1">#generated by this backwards pass on the real images</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="k">for</span> <span class="n">gradsr</span><span class="p">,</span> <span class="n">gradsf</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">discriminator</span><span class="o">.</span><span class="n">_exec_group</span><span class="o">.</span><span class="n">grad_arrays</span><span class="p">,</span> <span class="n">gradD</span><span class="p">):</span>
<span class="k">for</span> <span class="n">gradr</span><span class="p">,</span> <span class="n">gradf</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">gradsr</span><span class="p">,</span> <span class="n">gradsf</span><span class="p">):</span>
<span class="n">gradr</span> <span class="o">+=</span> <span class="n">gradf</span>
<span class="c1">#Update gradient on the discriminator</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">update</span><span class="p">()</span>
<span class="c1">#Now that we've updated the discriminator, let's update the generator</span>
<span class="c1">#First do a forward pass and backwards pass on the newly updated discriminator</span>
<span class="c1">#With the current batch</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">DataBatch</span><span class="p">(</span><span class="n">outG</span><span class="p">,</span> <span class="p">[</span><span class="n">label</span><span class="p">]),</span> <span class="n">is_train</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="c1">#Get the input gradient from the backwards pass on the discriminator,</span>
<span class="c1">#and use it to do the backwards pass on the generator</span>
<span class="n">diffD</span> <span class="o">=</span> <span class="n">discriminator</span><span class="o">.</span><span class="n">get_input_grads</span><span class="p">()</span>
<span class="n">generator</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">diffD</span><span class="p">)</span>
<span class="c1">#Update the gradients on the generator</span>
<span class="n">generator</span><span class="o">.</span><span class="n">update</span><span class="p">()</span>
<span class="c1">#Increment to the next batch, printing every 50 batches</span>
<span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">%</span> <span class="mi">50</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">print</span><span class="p">(</span><span class="s1">'epoch:'</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="s1">'iter:'</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
<span class="k">print</span>
<span class="k">print</span><span class="p">(</span><span class="s2">" From generator: From MNIST:"</span><span class="p">)</span>
<span class="n">visualize</span><span class="p">(</span><span class="n">outG</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">(),</span> <span class="n">batch</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span>
</pre></div>
</div>
<p>This will train the GAN network and visualize the progress that we are making as the networks are trained. After every 50 iterations, we are calling the visualize function that we created earlier, which plots the intermediate results.</p>
<p>The plot on the left will represent what the Generator created (the fake image) in the most recent iteration. The plot on the right will represent the Original (real) image from the MNIST dataset that was inputted to the Discriminator on the same iteration.</p>
<p>As the training goes on, the Generator becomes better at generating realistic images. You can see this happening since the images on the left become closer to the original dataset with each iteration.</p>
</div>
<div class="section" id="summary">
<span id="summary"></span><h2>Summary<a class="headerlink" href="#summary" title="Permalink to this headline"></a></h2>
<p>We have now successfully used Apache MXNet to train a Deep Convolutional Generative Adversarial Network (DCGAN) using the MNIST dataset.</p>
<p>As a result, we have created two neural nets: a Generator, which is able to create images of handwritten digits from random numbers, and a Discriminator, which is able to take an image and determine if it is an image of handwritten digits.</p>
<p>Along the way, we have learned how to do the image manipulation and visualization that is associated with the training of deep neural nets. We have also learned how to use MXNet’s Module APIs to perform advanced model training functionality to fit the model.</p>
</div>
<div class="section" id="acknowledgements">
<span id="acknowledgements"></span><h2>Acknowledgements<a class="headerlink" href="#acknowledgements" title="Permalink to this headline"></a></h2>
<p>This tutorial is based on <a class="reference external" href="https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py">MXNet DCGAN codebase</a>,
<a class="reference external" href="https://arxiv.org/abs/1406.2661">The original paper on GANs</a>, as well as <a class="reference external" href="https://arxiv.org/abs/1511.06434">this paper on deep convolutional GANs</a>.</p>
<div class="btn-group" role="group">
<div class="download-btn"><a download="gan.ipynb" href="gan.ipynb"><span class="glyphicon glyphicon-download-alt"></span> gan.ipynb</a></div></div></div>
</div>
</div>
</div>
<div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation">
<div class="sphinxsidebarwrapper">
<h3><a href="../../index.html">Table Of Contents</a></h3>
<ul>
<li><a class="reference internal" href="#">Generative Adversarial Network (GAN)</a><ul>
<li><a class="reference internal" href="#deep-convolutional-generative-adversarial-networks">Deep Convolutional Generative Adversarial Networks</a></li>
<li><a class="reference internal" href="#how-to-use-this-tutorial">How to Use This Tutorial</a></li>
<li><a class="reference internal" href="#prerequisites">Prerequisites</a></li>
<li><a class="reference internal" href="#the-data">The Data</a></li>
<li><a class="reference internal" href="#prepare-the-data">Prepare the Data</a><ul>
<li><a class="reference internal" href="#preparing-the-mnsit-dataset">1. Preparing the MNSIT dataset</a></li>
<li><a class="reference internal" href="#preparing-random-numbers">2. Preparing Random Numbers</a></li>
</ul>
</li>
<li><a class="reference internal" href="#create-the-model">Create the Model</a><ul>
<li><a class="reference internal" href="#the-generator">The Generator</a></li>
<li><a class="reference internal" href="#the-discriminator">The Discriminator</a></li>
<li><a class="reference internal" href="#prepare-the-models-using-the-module-api">Prepare the models using the Module API</a></li>
<li><a class="reference internal" href="#visualizing-the-training">Visualizing The Training</a></li>
</ul>
</li>
<li><a class="reference internal" href="#fit-the-model">Fit the Model</a></li>
<li><a class="reference internal" href="#summary">Summary</a></li>
<li><a class="reference internal" href="#acknowledgements">Acknowledgements</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div><div class="footer">
<div class="section-disclaimer">
<div class="container">
<div>
<img alt="Apache Incubator logo" height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/>
<p>
Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p>
<p>
"Copyright © 2017-2018, The Apache Software Foundation
Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation."
</p>
</div>
</div>
</div>
</div> <!-- pagename != index -->
</div>
<script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="../../_static/js/sidebar.js" type="text/javascript"></script>
<script src="../../_static/js/search.js" type="text/javascript"></script>
<script src="../../_static/js/navbar.js" type="text/javascript"></script>
<script src="../../_static/js/clipboard.min.js" type="text/javascript"></script>
<script src="../../_static/js/copycode.js" type="text/javascript"></script>
<script src="../../_static/js/page.js" type="text/javascript"></script>
<script src="../../_static/js/docversion.js" type="text/javascript"></script>
<script type="text/javascript">
$('body').ready(function () {
$('body').css('visibility', 'visible');
});
</script>
</body>
</html>