blob: 483ba566b266e677b9748d1555119b48ce4008ac [file] [log] [blame]
<!DOCTYPE html>
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta charset="utf-8" />
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<meta http-equiv="x-ua-compatible" content="ie=edge">
<style>
.dropdown {
position: relative;
display: inline-block;
}
.dropdown-content {
display: none;
position: absolute;
background-color: #f9f9f9;
min-width: 160px;
box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
padding: 12px 16px;
z-index: 1;
text-align: left;
}
.dropdown:hover .dropdown-content {
display: block;
}
.dropdown-option:hover {
color: #FF4500;
}
.dropdown-option-active {
color: #FF4500;
font-weight: lighter;
}
.dropdown-option {
color: #000000;
font-weight: lighter;
}
.dropdown-header {
color: #FFFFFF;
display: inline-flex;
}
.dropdown-caret {
width: 18px;
}
.dropdown-caret-path {
fill: #FFFFFF;
}
</style>
<title>Image similarity search with InfoGAN &#8212; Apache MXNet documentation</title>
<link rel="stylesheet" href="../../../../_static/basic.css" type="text/css" />
<link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" type="text/css" href="../../../../_static/mxnet.css" />
<link rel="stylesheet" href="../../../../_static/material-design-lite-1.3.0/material.blue-deep_orange.min.css" type="text/css" />
<link rel="stylesheet" href="../../../../_static/sphinx_materialdesign_theme.css" type="text/css" />
<link rel="stylesheet" href="../../../../_static/fontawesome/all.css" type="text/css" />
<link rel="stylesheet" href="../../../../_static/fonts.css" type="text/css" />
<link rel="stylesheet" href="../../../../_static/feedback.css" type="text/css" />
<script id="documentation_options" data-url_root="../../../../" src="../../../../_static/documentation_options.js"></script>
<script src="../../../../_static/jquery.js"></script>
<script src="../../../../_static/underscore.js"></script>
<script src="../../../../_static/doctools.js"></script>
<script src="../../../../_static/language_data.js"></script>
<script src="../../../../_static/matomo_analytics.js"></script>
<script src="../../../../_static/autodoc.js"></script>
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
<link rel="shortcut icon" href="../../../../_static/mxnet-icon.png"/>
<link rel="index" title="Index" href="../../../../genindex.html" />
<link rel="search" title="Search" href="../../../../search.html" />
<link rel="next" title="Handwritten Digit Recognition" href="mnist.html" />
<link rel="prev" title="Image Augmentation" href="image-augmentation.html" />
</head>
<body><header class="site-header" role="banner">
<div class="wrapper">
<a class="site-title" rel="author" href="/versions/1.9.1/"><img
src="../../../../_static/mxnet_logo.png" class="site-header-logo"></a>
<nav class="site-nav">
<input type="checkbox" id="nav-trigger" class="nav-trigger"/>
<label for="nav-trigger">
<span class="menu-icon">
<svg viewBox="0 0 18 15" width="18px" height="15px">
<path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
</svg>
</span>
</label>
<div class="trigger">
<a class="page-link" href="/versions/1.9.1/get_started">Get Started</a>
<a class="page-link" href="/versions/1.9.1/features">Features</a>
<a class="page-link" href="/versions/1.9.1/ecosystem">Ecosystem</a>
<a class="page-link page-current" href="/versions/1.9.1/api">Docs & Tutorials</a>
<a class="page-link" href="/versions/1.9.1/trusted_by">Trusted By</a>
<a class="page-link" href="https://github.com/apache/mxnet">GitHub</a>
<div class="dropdown" style="min-width:100px">
<span class="dropdown-header">Apache
<svg class="dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content" style="min-width:250px">
<a href="https://www.apache.org/foundation/">Apache Software Foundation</a>
<a href="https://incubator.apache.org/">Apache Incubator</a>
<a href="https://www.apache.org/licenses/">License</a>
<a href="/versions/1.9.1/api/faq/security.html">Security</a>
<a href="https://privacy.apache.org/policies/privacy-policy-public.html">Privacy</a>
<a href="https://www.apache.org/events/current-event">Events</a>
<a href="https://www.apache.org/foundation/sponsorship.html">Sponsorship</a>
<a href="https://www.apache.org/foundation/thanks.html">Thanks</a>
</div>
</div>
<div class="dropdown">
<span class="dropdown-header">1.9.1
<svg class="dropdown-caret" viewBox="0 0 32 32" class="icon icon-caret-bottom" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
</span>
<div class="dropdown-content">
<a class="dropdown-option" href="/">master</a><br>
<a class="dropdown-option-active" href="/versions/1.9.1/">1.9.1</a><br>
<a class="dropdown-option" href="/versions/1.8.0/">1.8.0</a><br>
<a class="dropdown-option" href="/versions/1.7.0/">1.7.0</a><br>
<a class="dropdown-option" href="/versions/1.6.0/">1.6.0</a><br>
<a class="dropdown-option" href="/versions/1.5.0/">1.5.0</a><br>
<a class="dropdown-option" href="/versions/1.4.1/">1.4.1</a><br>
<a class="dropdown-option" href="/versions/1.3.1/">1.3.1</a><br>
<a class="dropdown-option" href="/versions/1.2.1/">1.2.1</a><br>
<a class="dropdown-option" href="/versions/1.1.0/">1.1.0</a><br>
<a class="dropdown-option" href="/versions/1.0.0/">1.0.0</a><br>
<a class="dropdown-option" href="/versions/0.12.1/">0.12.1</a><br>
<a class="dropdown-option" href="/versions/0.11.0/">0.11.0</a>
</div>
</div>
</div>
</nav>
</div>
</header>
<div class="mdl-layout mdl-js-layout mdl-layout--fixed-header mdl-layout--fixed-drawer"><header class="mdl-layout__header mdl-layout__header--waterfall ">
<div class="mdl-layout__header-row">
<nav class="mdl-navigation breadcrumb">
<a class="mdl-navigation__link" href="../../../index.html">Python Tutorials</a><i class="material-icons">navigate_next</i>
<a class="mdl-navigation__link" href="../../index.html">Packages</a><i class="material-icons">navigate_next</i>
<a class="mdl-navigation__link" href="../index.html">Gluon</a><i class="material-icons">navigate_next</i>
<a class="mdl-navigation__link" href="index.html">Image Tutorials</a><i class="material-icons">navigate_next</i>
<a class="mdl-navigation__link is-active">Image similarity search with InfoGAN</a>
</nav>
<div class="mdl-layout-spacer"></div>
<nav class="mdl-navigation">
<form class="form-inline pull-sm-right" action="../../../../search.html" method="get">
<div class="mdl-textfield mdl-js-textfield mdl-textfield--expandable mdl-textfield--floating-label mdl-textfield--align-right">
<label id="quick-search-icon" class="mdl-button mdl-js-button mdl-button--icon" for="waterfall-exp">
<i class="material-icons">search</i>
</label>
<div class="mdl-textfield__expandable-holder">
<input class="mdl-textfield__input" type="text" name="q" id="waterfall-exp" placeholder="Search" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</div>
</div>
<div class="mdl-tooltip" data-mdl-for="quick-search-icon">
Quick search
</div>
</form>
<a id="button-show-source"
class="mdl-button mdl-js-button mdl-button--icon"
href="../../../../_sources/tutorials/packages/gluon/image/info_gan.ipynb" rel="nofollow">
<i class="material-icons">code</i>
</a>
<div class="mdl-tooltip" data-mdl-for="button-show-source">
Show Source
</div>
</nav>
</div>
<div class="mdl-layout__header-row header-links">
<div class="mdl-layout-spacer"></div>
<nav class="mdl-navigation">
</nav>
</div>
</header><header class="mdl-layout__drawer">
<div class="globaltoc">
<span class="mdl-layout-title toc">Table Of Contents</span>
<nav class="mdl-navigation">
<ul class="current">
<li class="toctree-l1 current"><a class="reference internal" href="../../../index.html">Python Tutorials</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="../../../getting-started/index.html">Getting Started</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../getting-started/crash-course/index.html">Crash Course</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/1-ndarray.html">Manipulate data with <code class="docutils literal notranslate"><span class="pre">ndarray</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/2-nn.html">Create a neural network</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/3-autograd.html">Automatic differentiation with <code class="docutils literal notranslate"><span class="pre">autograd</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/4-train.html">Train the neural network</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/5-predict.html">Predict with a pre-trained model</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/6-use_gpus.html">Use GPUs</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../getting-started/to-mxnet/index.html">Moving to MXNet from Other Frameworks</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/to-mxnet/pytorch.html">PyTorch vs Apache MXNet</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../getting-started/gluon_from_experiment_to_deployment.html">Gluon: from experiment to deployment</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../getting-started/logistic_regression_explained.html">Logistic regression explained</a></li>
<li class="toctree-l3"><a class="reference external" href="https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html">MNIST</a></li>
</ul>
</li>
<li class="toctree-l2 current"><a class="reference internal" href="../../index.html">Packages</a><ul class="current">
<li class="toctree-l3"><a class="reference internal" href="../../autograd/index.html">Automatic Differentiation</a></li>
<li class="toctree-l3 current"><a class="reference internal" href="../index.html">Gluon</a><ul class="current">
<li class="toctree-l4"><a class="reference internal" href="../blocks/index.html">Blocks</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../blocks/custom-layer.html">Custom Layers</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/custom_layer_beginners.html">Customer Layers (Beginners)</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/hybridize.html">Hybridize</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/init.html">Initialization</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/naming.html">Parameter and Block Naming</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/nn.html">Layers and Blocks</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/parameters.html">Parameter Management</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/save_load_params.html">Saving and Loading Gluon Models</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/activations/activations.html">Activation Blocks</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../data/index.html">Data Tutorials</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../data/data_augmentation.html">Image Augmentation</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/data_augmentation.html#Spatial-Augmentation">Spatial Augmentation</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/data_augmentation.html#Color-Augmentation">Color Augmentation</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/data_augmentation.html#Composed-Augmentations">Composed Augmentations</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/datasets.html">Gluon <code class="docutils literal notranslate"><span class="pre">Dataset</span></code>s and <code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/datasets.html#Using-own-data-with-included-Datasets">Using own data with included <code class="docutils literal notranslate"><span class="pre">Dataset</span></code>s</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/datasets.html#Using-own-data-with-custom-Datasets">Using own data with custom <code class="docutils literal notranslate"><span class="pre">Dataset</span></code>s</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/datasets.html#Appendix:-Upgrading-from-Module-DataIter-to-Gluon-DataLoader">Appendix: Upgrading from Module <code class="docutils literal notranslate"><span class="pre">DataIter</span></code> to Gluon <code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a></li>
</ul>
</li>
<li class="toctree-l4 current"><a class="reference internal" href="index.html">Image Tutorials</a><ul class="current">
<li class="toctree-l5"><a class="reference internal" href="image-augmentation.html">Image Augmentation</a></li>
<li class="toctree-l5 current"><a class="current reference internal" href="#">Image similarity search with InfoGAN</a></li>
<li class="toctree-l5"><a class="reference internal" href="mnist.html">Handwritten Digit Recognition</a></li>
<li class="toctree-l5"><a class="reference internal" href="pretrained_models.html">Using pre-trained models in MXNet</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../loss/index.html">Losses</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../loss/custom-loss.html">Custom Loss Blocks</a></li>
<li class="toctree-l5"><a class="reference internal" href="../loss/kl_divergence.html">Kullback-Leibler (KL) Divergence</a></li>
<li class="toctree-l5"><a class="reference internal" href="../loss/loss.html">Loss functions</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../text/index.html">Text Tutorials</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../text/gnmt.html">Google Neural Machine Translation</a></li>
<li class="toctree-l5"><a class="reference internal" href="../text/transformer.html">Machine Translation with Transformer</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../training/index.html">Training</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../training/fit_api_tutorial.html">MXNet Gluon Fit API</a></li>
<li class="toctree-l5"><a class="reference internal" href="../training/trainer.html">Trainer</a></li>
<li class="toctree-l5"><a class="reference internal" href="../training/learning_rates/index.html">Learning Rates</a><ul>
<li class="toctree-l6"><a class="reference internal" href="../training/learning_rates/learning_rate_finder.html">Learning Rate Finder</a></li>
<li class="toctree-l6"><a class="reference internal" href="../training/learning_rates/learning_rate_schedules.html">Learning Rate Schedules</a></li>
<li class="toctree-l6"><a class="reference internal" href="../training/learning_rates/learning_rate_schedules_advanced.html">Advanced Learning Rate Schedules</a></li>
</ul>
</li>
<li class="toctree-l5"><a class="reference internal" href="../training/normalization/index.html">Normalization Blocks</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../kvstore/index.html">KVStore</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../kvstore/kvstore.html">Distributed Key-Value Store</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../ndarray/index.html">NDArray</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/01-ndarray-intro.html">An Intro: Manipulate Data the MXNet Way with NDArray</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/02-ndarray-operations.html">NDArray Operations</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/03-ndarray-contexts.html">NDArray Contexts</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/gotchas_numpy_in_mxnet.html">Gotchas using NumPy in Apache MXNet</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/sparse/index.html">Tutorials</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../../ndarray/sparse/csr.html">CSRNDArray - NDArray in Compressed Sparse Row Storage Format</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../ndarray/sparse/row_sparse.html">RowSparseNDArray - NDArray for Sparse Gradient Updates</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../ndarray/sparse/train.html">Train a Linear Regression Model with Sparse Symbols</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../ndarray/sparse/train_gluon.html">Sparse NDArrays with Gluon</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../onnx/index.html">ONNX</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../onnx/fine_tuning_gluon.html">Fine-tuning an ONNX model</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../onnx/inference_on_onnx_model.html">Running inference on MXNet/Gluon from an ONNX model</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../onnx/super_resolution.html">Importing an ONNX model into MXNet</a></li>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/python/docs/tutorials/deploy/export/onnx.html">Export ONNX Models</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../optimizer/index.html">Optimizers</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../viz/index.html">Visualization</a><ul>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/faq/visualize_graph">Visualize networks</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../performance/index.html">Performance</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../performance/compression/index.html">Compression</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/compression/int8.html">Deploy with int-8</a></li>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/faq/float16">Float16</a></li>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/faq/gradient_compression">Gradient Compression</a></li>
<li class="toctree-l4"><a class="reference external" href="https://gluon-cv.mxnet.io/build/examples_deployment/int8_inference.html">GluonCV with Quantized Models</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../performance/backend/index.html">Accelerated Backend Tools</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/mkldnn/index.html">Intel MKL-DNN</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../../../performance/backend/mkldnn/mkldnn_quantization.html">Quantize with MKL-DNN backend</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../../performance/backend/mkldnn/mkldnn_quantization.html#Improving-accuracy-with-Intel®-Neural-Compressor">Improving accuracy with Intel® Neural Compressor</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../../performance/backend/mkldnn/mkldnn_readme.html">Install MXNet with MKL-DNN</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/tensorrt/index.html">TensorRT</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../../../performance/backend/tensorrt/tensorrt.html">Optimizing Deep Learning Computation Graphs with TensorRT</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/tvm.html">Use TVM</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/profiler.html">Profiling MXNet Models</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/amp.html">Using AMP: Automatic Mixed Precision</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../deploy/index.html">Deployment</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../deploy/export/index.html">Export</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/export/onnx.html">Exporting to ONNX format</a></li>
<li class="toctree-l4"><a class="reference external" href="https://gluon-cv.mxnet.io/build/examples_deployment/export_network.html">Export Gluon CV Models</a></li>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html">Save / Load Parameters</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../deploy/inference/index.html">Inference</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/inference/cpp.html">Deploy into C++</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/inference/image_classification_jetson.html">Image Classication using pretrained ResNet-50 model on Jetson module</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/inference/scala.html">Deploy into a Java or Scala Environment</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/inference/wine_detector.html">Real-time Object Detection with MXNet On The Raspberry Pi</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../deploy/run-on-aws/index.html">Run on AWS</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/run-on-aws/use_ec2.html">Run on an EC2 Instance</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/run-on-aws/use_sagemaker.html">Run on Amazon SageMaker</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/run-on-aws/cloud.html">MXNet on the Cloud</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../extend/index.html">Extend</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../extend/custom_layer.html">Custom Layers</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../extend/customop.html">Custom Numpy Operators</a></li>
<li class="toctree-l3"><a class="reference external" href="https://mxnet.apache.org/api/faq/new_op">New Operator Creation</a></li>
<li class="toctree-l3"><a class="reference external" href="https://mxnet.apache.org/api/faq/add_op_in_backend">New Operator in MXNet Backend</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../../../api/index.html">Python API</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/ndarray/index.html">mxnet.ndarray</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/ndarray.html">ndarray</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/contrib/index.html">ndarray.contrib</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/image/index.html">ndarray.image</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/linalg/index.html">ndarray.linalg</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/op/index.html">ndarray.op</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/random/index.html">ndarray.random</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/register/index.html">ndarray.register</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/sparse/index.html">ndarray.sparse</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/utils/index.html">ndarray.utils</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/gluon/index.html">mxnet.gluon</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/block.html">gluon.Block</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/hybrid_block.html">gluon.HybridBlock</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/symbol_block.html">gluon.SymbolBlock</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/constant.html">gluon.Constant</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/parameter.html">gluon.Parameter</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/parameter_dict.html">gluon.ParameterDict</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/trainer.html">gluon.Trainer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/contrib/index.html">gluon.contrib</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/data/index.html">gluon.data</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../../api/gluon/data/vision/index.html">data.vision</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../../../../api/gluon/data/vision/datasets/index.html">vision.datasets</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../../../api/gluon/data/vision/transforms/index.html">vision.transforms</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/loss/index.html">gluon.loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/model_zoo/index.html">gluon.model_zoo.vision</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/nn/index.html">gluon.nn</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/rnn/index.html">gluon.rnn</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/utils/index.html">gluon.utils</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/autograd/index.html">mxnet.autograd</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/initializer/index.html">mxnet.initializer</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/optimizer/index.html">mxnet.optimizer</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/lr_scheduler/index.html">mxnet.lr_scheduler</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/metric/index.html">mxnet.metric</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/kvstore/index.html">mxnet.kvstore</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/symbol/index.html">mxnet.symbol</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/symbol.html">symbol</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/contrib/index.html">symbol.contrib</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/image/index.html">symbol.image</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/linalg/index.html">symbol.linalg</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/op/index.html">symbol.op</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/random/index.html">symbol.random</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/register/index.html">symbol.register</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/sparse/index.html">symbol.sparse</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/module/index.html">mxnet.module</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/contrib/index.html">mxnet.contrib</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/autograd/index.html">contrib.autograd</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/io/index.html">contrib.io</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/ndarray/index.html">contrib.ndarray</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/onnx/index.html">contrib.onnx</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/quantization/index.html">contrib.quantization</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/symbol/index.html">contrib.symbol</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/tensorboard/index.html">contrib.tensorboard</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/tensorrt/index.html">contrib.tensorrt</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/text/index.html">contrib.text</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/mxnet/index.html">mxnet</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/attribute/index.html">mxnet.attribute</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/base/index.html">mxnet.base</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/callback/index.html">mxnet.callback</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/context/index.html">mxnet.context</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/engine/index.html">mxnet.engine</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/executor/index.html">mxnet.executor</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/executor_manager/index.html">mxnet.executor_manager</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/image/index.html">mxnet.image</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/io/index.html">mxnet.io</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/kvstore_server/index.html">mxnet.kvstore_server</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/libinfo/index.html">mxnet.libinfo</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/log/index.html">mxnet.log</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/model/index.html">mxnet.model</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/monitor/index.html">mxnet.monitor</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/name/index.html">mxnet.name</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/notebook/index.html">mxnet.notebook</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/operator/index.html">mxnet.operator</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/profiler/index.html">mxnet.profiler</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/random/index.html">mxnet.random</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/recordio/index.html">mxnet.recordio</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/registry/index.html">mxnet.registry</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/rtc/index.html">mxnet.rtc</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/runtime/index.html">mxnet.runtime</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/test_utils/index.html">mxnet.test_utils</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/torch/index.html">mxnet.torch</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/util/index.html">mxnet.util</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/visualization/index.html">mxnet.visualization</a></li>
</ul>
</li>
</ul>
</li>
</ul>
</nav>
</div>
</header>
<main class="mdl-layout__content" tabIndex="0">
<script type="text/javascript" src="../../../../_static/sphinx_materialdesign_theme.js "></script>
<script type="text/javascript" src="../../../../_static/feedback.js"></script>
<header class="mdl-layout__drawer">
<div class="globaltoc">
<span class="mdl-layout-title toc">Table Of Contents</span>
<nav class="mdl-navigation">
<ul class="current">
<li class="toctree-l1 current"><a class="reference internal" href="../../../index.html">Python Tutorials</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="../../../getting-started/index.html">Getting Started</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../getting-started/crash-course/index.html">Crash Course</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/1-ndarray.html">Manipulate data with <code class="docutils literal notranslate"><span class="pre">ndarray</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/2-nn.html">Create a neural network</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/3-autograd.html">Automatic differentiation with <code class="docutils literal notranslate"><span class="pre">autograd</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/4-train.html">Train the neural network</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/5-predict.html">Predict with a pre-trained model</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/6-use_gpus.html">Use GPUs</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../getting-started/to-mxnet/index.html">Moving to MXNet from Other Frameworks</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/to-mxnet/pytorch.html">PyTorch vs Apache MXNet</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../getting-started/gluon_from_experiment_to_deployment.html">Gluon: from experiment to deployment</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../getting-started/logistic_regression_explained.html">Logistic regression explained</a></li>
<li class="toctree-l3"><a class="reference external" href="https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html">MNIST</a></li>
</ul>
</li>
<li class="toctree-l2 current"><a class="reference internal" href="../../index.html">Packages</a><ul class="current">
<li class="toctree-l3"><a class="reference internal" href="../../autograd/index.html">Automatic Differentiation</a></li>
<li class="toctree-l3 current"><a class="reference internal" href="../index.html">Gluon</a><ul class="current">
<li class="toctree-l4"><a class="reference internal" href="../blocks/index.html">Blocks</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../blocks/custom-layer.html">Custom Layers</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/custom_layer_beginners.html">Customer Layers (Beginners)</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/hybridize.html">Hybridize</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/init.html">Initialization</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/naming.html">Parameter and Block Naming</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/nn.html">Layers and Blocks</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/parameters.html">Parameter Management</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/save_load_params.html">Saving and Loading Gluon Models</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/activations/activations.html">Activation Blocks</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../data/index.html">Data Tutorials</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../data/data_augmentation.html">Image Augmentation</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/data_augmentation.html#Spatial-Augmentation">Spatial Augmentation</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/data_augmentation.html#Color-Augmentation">Color Augmentation</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/data_augmentation.html#Composed-Augmentations">Composed Augmentations</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/datasets.html">Gluon <code class="docutils literal notranslate"><span class="pre">Dataset</span></code>s and <code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/datasets.html#Using-own-data-with-included-Datasets">Using own data with included <code class="docutils literal notranslate"><span class="pre">Dataset</span></code>s</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/datasets.html#Using-own-data-with-custom-Datasets">Using own data with custom <code class="docutils literal notranslate"><span class="pre">Dataset</span></code>s</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/datasets.html#Appendix:-Upgrading-from-Module-DataIter-to-Gluon-DataLoader">Appendix: Upgrading from Module <code class="docutils literal notranslate"><span class="pre">DataIter</span></code> to Gluon <code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a></li>
</ul>
</li>
<li class="toctree-l4 current"><a class="reference internal" href="index.html">Image Tutorials</a><ul class="current">
<li class="toctree-l5"><a class="reference internal" href="image-augmentation.html">Image Augmentation</a></li>
<li class="toctree-l5 current"><a class="current reference internal" href="#">Image similarity search with InfoGAN</a></li>
<li class="toctree-l5"><a class="reference internal" href="mnist.html">Handwritten Digit Recognition</a></li>
<li class="toctree-l5"><a class="reference internal" href="pretrained_models.html">Using pre-trained models in MXNet</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../loss/index.html">Losses</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../loss/custom-loss.html">Custom Loss Blocks</a></li>
<li class="toctree-l5"><a class="reference internal" href="../loss/kl_divergence.html">Kullback-Leibler (KL) Divergence</a></li>
<li class="toctree-l5"><a class="reference internal" href="../loss/loss.html">Loss functions</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../text/index.html">Text Tutorials</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../text/gnmt.html">Google Neural Machine Translation</a></li>
<li class="toctree-l5"><a class="reference internal" href="../text/transformer.html">Machine Translation with Transformer</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../training/index.html">Training</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../training/fit_api_tutorial.html">MXNet Gluon Fit API</a></li>
<li class="toctree-l5"><a class="reference internal" href="../training/trainer.html">Trainer</a></li>
<li class="toctree-l5"><a class="reference internal" href="../training/learning_rates/index.html">Learning Rates</a><ul>
<li class="toctree-l6"><a class="reference internal" href="../training/learning_rates/learning_rate_finder.html">Learning Rate Finder</a></li>
<li class="toctree-l6"><a class="reference internal" href="../training/learning_rates/learning_rate_schedules.html">Learning Rate Schedules</a></li>
<li class="toctree-l6"><a class="reference internal" href="../training/learning_rates/learning_rate_schedules_advanced.html">Advanced Learning Rate Schedules</a></li>
</ul>
</li>
<li class="toctree-l5"><a class="reference internal" href="../training/normalization/index.html">Normalization Blocks</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../kvstore/index.html">KVStore</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../kvstore/kvstore.html">Distributed Key-Value Store</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../ndarray/index.html">NDArray</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/01-ndarray-intro.html">An Intro: Manipulate Data the MXNet Way with NDArray</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/02-ndarray-operations.html">NDArray Operations</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/03-ndarray-contexts.html">NDArray Contexts</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/gotchas_numpy_in_mxnet.html">Gotchas using NumPy in Apache MXNet</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/sparse/index.html">Tutorials</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../../ndarray/sparse/csr.html">CSRNDArray - NDArray in Compressed Sparse Row Storage Format</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../ndarray/sparse/row_sparse.html">RowSparseNDArray - NDArray for Sparse Gradient Updates</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../ndarray/sparse/train.html">Train a Linear Regression Model with Sparse Symbols</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../ndarray/sparse/train_gluon.html">Sparse NDArrays with Gluon</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../onnx/index.html">ONNX</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../onnx/fine_tuning_gluon.html">Fine-tuning an ONNX model</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../onnx/inference_on_onnx_model.html">Running inference on MXNet/Gluon from an ONNX model</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../onnx/super_resolution.html">Importing an ONNX model into MXNet</a></li>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/python/docs/tutorials/deploy/export/onnx.html">Export ONNX Models</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../optimizer/index.html">Optimizers</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../viz/index.html">Visualization</a><ul>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/faq/visualize_graph">Visualize networks</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../performance/index.html">Performance</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../performance/compression/index.html">Compression</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/compression/int8.html">Deploy with int-8</a></li>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/faq/float16">Float16</a></li>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/faq/gradient_compression">Gradient Compression</a></li>
<li class="toctree-l4"><a class="reference external" href="https://gluon-cv.mxnet.io/build/examples_deployment/int8_inference.html">GluonCV with Quantized Models</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../performance/backend/index.html">Accelerated Backend Tools</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/mkldnn/index.html">Intel MKL-DNN</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../../../performance/backend/mkldnn/mkldnn_quantization.html">Quantize with MKL-DNN backend</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../../performance/backend/mkldnn/mkldnn_quantization.html#Improving-accuracy-with-Intel®-Neural-Compressor">Improving accuracy with Intel® Neural Compressor</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../../performance/backend/mkldnn/mkldnn_readme.html">Install MXNet with MKL-DNN</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/tensorrt/index.html">TensorRT</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../../../performance/backend/tensorrt/tensorrt.html">Optimizing Deep Learning Computation Graphs with TensorRT</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/tvm.html">Use TVM</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/profiler.html">Profiling MXNet Models</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/amp.html">Using AMP: Automatic Mixed Precision</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../deploy/index.html">Deployment</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../deploy/export/index.html">Export</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/export/onnx.html">Exporting to ONNX format</a></li>
<li class="toctree-l4"><a class="reference external" href="https://gluon-cv.mxnet.io/build/examples_deployment/export_network.html">Export Gluon CV Models</a></li>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html">Save / Load Parameters</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../deploy/inference/index.html">Inference</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/inference/cpp.html">Deploy into C++</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/inference/image_classification_jetson.html">Image Classication using pretrained ResNet-50 model on Jetson module</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/inference/scala.html">Deploy into a Java or Scala Environment</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/inference/wine_detector.html">Real-time Object Detection with MXNet On The Raspberry Pi</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../deploy/run-on-aws/index.html">Run on AWS</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/run-on-aws/use_ec2.html">Run on an EC2 Instance</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/run-on-aws/use_sagemaker.html">Run on Amazon SageMaker</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/run-on-aws/cloud.html">MXNet on the Cloud</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../extend/index.html">Extend</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../extend/custom_layer.html">Custom Layers</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../extend/customop.html">Custom Numpy Operators</a></li>
<li class="toctree-l3"><a class="reference external" href="https://mxnet.apache.org/api/faq/new_op">New Operator Creation</a></li>
<li class="toctree-l3"><a class="reference external" href="https://mxnet.apache.org/api/faq/add_op_in_backend">New Operator in MXNet Backend</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../../../api/index.html">Python API</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/ndarray/index.html">mxnet.ndarray</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/ndarray.html">ndarray</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/contrib/index.html">ndarray.contrib</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/image/index.html">ndarray.image</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/linalg/index.html">ndarray.linalg</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/op/index.html">ndarray.op</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/random/index.html">ndarray.random</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/register/index.html">ndarray.register</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/sparse/index.html">ndarray.sparse</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/utils/index.html">ndarray.utils</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/gluon/index.html">mxnet.gluon</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/block.html">gluon.Block</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/hybrid_block.html">gluon.HybridBlock</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/symbol_block.html">gluon.SymbolBlock</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/constant.html">gluon.Constant</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/parameter.html">gluon.Parameter</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/parameter_dict.html">gluon.ParameterDict</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/trainer.html">gluon.Trainer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/contrib/index.html">gluon.contrib</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/data/index.html">gluon.data</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../../api/gluon/data/vision/index.html">data.vision</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../../../../api/gluon/data/vision/datasets/index.html">vision.datasets</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../../../api/gluon/data/vision/transforms/index.html">vision.transforms</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/loss/index.html">gluon.loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/model_zoo/index.html">gluon.model_zoo.vision</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/nn/index.html">gluon.nn</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/rnn/index.html">gluon.rnn</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/utils/index.html">gluon.utils</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/autograd/index.html">mxnet.autograd</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/initializer/index.html">mxnet.initializer</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/optimizer/index.html">mxnet.optimizer</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/lr_scheduler/index.html">mxnet.lr_scheduler</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/metric/index.html">mxnet.metric</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/kvstore/index.html">mxnet.kvstore</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/symbol/index.html">mxnet.symbol</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/symbol.html">symbol</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/contrib/index.html">symbol.contrib</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/image/index.html">symbol.image</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/linalg/index.html">symbol.linalg</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/op/index.html">symbol.op</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/random/index.html">symbol.random</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/register/index.html">symbol.register</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/sparse/index.html">symbol.sparse</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/module/index.html">mxnet.module</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/contrib/index.html">mxnet.contrib</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/autograd/index.html">contrib.autograd</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/io/index.html">contrib.io</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/ndarray/index.html">contrib.ndarray</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/onnx/index.html">contrib.onnx</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/quantization/index.html">contrib.quantization</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/symbol/index.html">contrib.symbol</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/tensorboard/index.html">contrib.tensorboard</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/tensorrt/index.html">contrib.tensorrt</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/text/index.html">contrib.text</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/mxnet/index.html">mxnet</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/attribute/index.html">mxnet.attribute</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/base/index.html">mxnet.base</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/callback/index.html">mxnet.callback</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/context/index.html">mxnet.context</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/engine/index.html">mxnet.engine</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/executor/index.html">mxnet.executor</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/executor_manager/index.html">mxnet.executor_manager</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/image/index.html">mxnet.image</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/io/index.html">mxnet.io</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/kvstore_server/index.html">mxnet.kvstore_server</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/libinfo/index.html">mxnet.libinfo</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/log/index.html">mxnet.log</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/model/index.html">mxnet.model</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/monitor/index.html">mxnet.monitor</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/name/index.html">mxnet.name</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/notebook/index.html">mxnet.notebook</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/operator/index.html">mxnet.operator</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/profiler/index.html">mxnet.profiler</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/random/index.html">mxnet.random</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/recordio/index.html">mxnet.recordio</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/registry/index.html">mxnet.registry</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/rtc/index.html">mxnet.rtc</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/runtime/index.html">mxnet.runtime</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/test_utils/index.html">mxnet.test_utils</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/torch/index.html">mxnet.torch</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/util/index.html">mxnet.util</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/visualization/index.html">mxnet.visualization</a></li>
</ul>
</li>
</ul>
</li>
</ul>
</nav>
</div>
</header>
<div class="document">
<div class="page-content" role="main">
<!--- Licensed to the Apache Software Foundation (ASF) under one --><!--- or more contributor license agreements. See the NOTICE file --><!--- distributed with this work for additional information --><!--- regarding copyright ownership. The ASF licenses this file --><!--- to you under the Apache License, Version 2.0 (the --><!--- "License"); you may not use this file except in compliance --><!--- with the License. You may obtain a copy of the License at --><!--- http://www.apache.org/licenses/LICENSE-2.0 --><!--- Unless required by applicable law or agreed to in writing, --><!--- software distributed under the License is distributed on an --><!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --><!--- KIND, either express or implied. See the License for the --><!--- specific language governing permissions and limitations --><!--- under the License. --><div class="section" id="Image-similarity-search-with-InfoGAN">
<h1>Image similarity search with InfoGAN<a class="headerlink" href="#Image-similarity-search-with-InfoGAN" title="Permalink to this headline"></a></h1>
<p>This notebook shows how to implement an InfoGAN based on Gluon. InfoGAN is an extension of GANs, where the generator input is split in 2 parts: random noise and a latent code (see <a class="reference external" href="https://arxiv.org/pdf/1606.03657.pdf">InfoGAN Paper</a>). The codes are made meaningful by maximizing the mutual information between code and generator output. InfoGAN learns a disentangled representation in a completely unsupervised manner. It can be used for many applications such as image similarity search. This
notebook uses the DCGAN example from the <a class="reference external" href="https://gluon.mxnet.io/chapter14_generative-adversarial-networks/dcgan.html">Straight Dope Book</a> and extends it to create an InfoGAN.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">print_function</span>
<span class="kn">from</span> <span class="nn">datetime</span> <span class="kn">import</span> <span class="n">datetime</span>
<span class="kn">import</span> <span class="nn">logging</span>
<span class="kn">import</span> <span class="nn">multiprocessing</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">sys</span>
<span class="kn">import</span> <span class="nn">tarfile</span>
<span class="kn">import</span> <span class="nn">time</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">from</span> <span class="nn">mxboard</span> <span class="kn">import</span> <span class="n">SummaryWriter</span>
<span class="kn">import</span> <span class="nn">mxnet</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">from</span> <span class="nn">mxnet</span> <span class="kn">import</span> <span class="n">gluon</span>
<span class="kn">from</span> <span class="nn">mxnet</span> <span class="kn">import</span> <span class="n">ndarray</span> <span class="k">as</span> <span class="n">nd</span>
<span class="kn">from</span> <span class="nn">mxnet.gluon</span> <span class="kn">import</span> <span class="n">nn</span><span class="p">,</span> <span class="n">utils</span>
<span class="kn">from</span> <span class="nn">mxnet</span> <span class="kn">import</span> <span class="n">autograd</span>
</pre></div>
</div>
<p>The latent code vector can contain several variables, which can be categorical and/or continuous. We set <code class="docutils literal notranslate"><span class="pre">n_continuous</span></code> to 2 and <code class="docutils literal notranslate"><span class="pre">n_categories</span></code> to 10.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">z_dim</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">n_continuous</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">n_categories</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">gpu</span><span class="p">()</span> <span class="k">if</span> <span class="n">mx</span><span class="o">.</span><span class="n">context</span><span class="o">.</span><span class="n">num_gpus</span><span class="p">()</span> <span class="k">else</span> <span class="n">mx</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
</pre></div>
</div>
<p>Some functions to load and normalize images.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">lfw_url</span> <span class="o">=</span> <span class="s1">&#39;http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz&#39;</span>
<span class="n">data_path</span> <span class="o">=</span> <span class="s1">&#39;lfw_dataset&#39;</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">data_path</span><span class="p">):</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">data_path</span><span class="p">)</span>
<span class="n">data_file</span> <span class="o">=</span> <span class="n">utils</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">lfw_url</span><span class="p">)</span>
<span class="k">with</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">data_file</span><span class="p">)</span> <span class="k">as</span> <span class="n">tar</span><span class="p">:</span>
<span class="n">tar</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="n">path</span><span class="o">=</span><span class="n">data_path</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">transform</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">width</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">height</span><span class="o">=</span><span class="mi">64</span><span class="p">):</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">imresize</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">height</span><span class="p">)</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">nd</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">0</span><span class="p">,</span><span class="mi">1</span><span class="p">))</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span><span class="o">/</span><span class="mf">127.5</span> <span class="o">-</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">nd</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="k">return</span> <span class="n">data</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,)</span> <span class="o">+</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">get_files</span><span class="p">(</span><span class="n">data_dir</span><span class="p">):</span>
<span class="n">images</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">filenames</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">path</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">fnames</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">walk</span><span class="p">(</span><span class="n">data_dir</span><span class="p">):</span>
<span class="k">for</span> <span class="n">fname</span> <span class="ow">in</span> <span class="n">fnames</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">fname</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="s1">&#39;.jpg&#39;</span><span class="p">):</span>
<span class="k">continue</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">fname</span><span class="p">)</span>
<span class="n">img_arr</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
<span class="n">img_arr</span> <span class="o">=</span> <span class="n">transform</span><span class="p">(</span><span class="n">img_arr</span><span class="p">)</span>
<span class="n">images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">img_arr</span><span class="p">)</span>
<span class="n">filenames</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">path</span> <span class="o">+</span> <span class="s2">&quot;/&quot;</span> <span class="o">+</span> <span class="n">fname</span><span class="p">)</span>
<span class="k">return</span> <span class="n">images</span><span class="p">,</span> <span class="n">filenames</span>
</pre></div>
</div>
<p>Load the dataset <code class="docutils literal notranslate"><span class="pre">lfw_dataset</span></code> which contains images of celebrities.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">data_dir</span> <span class="o">=</span> <span class="s1">&#39;lfw_dataset&#39;</span>
<span class="n">images</span><span class="p">,</span> <span class="n">filenames</span> <span class="o">=</span> <span class="n">get_files</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span>
<span class="n">split</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">images</span><span class="p">)</span><span class="o">*</span><span class="mf">0.8</span><span class="p">)</span>
<span class="n">test_images</span> <span class="o">=</span> <span class="n">images</span><span class="p">[</span><span class="n">split</span><span class="p">:]</span>
<span class="n">test_filenames</span> <span class="o">=</span> <span class="n">filenames</span><span class="p">[</span><span class="n">split</span><span class="p">:]</span>
<span class="n">train_images</span> <span class="o">=</span> <span class="n">images</span><span class="p">[:</span><span class="n">split</span><span class="p">]</span>
<span class="n">train_filenames</span> <span class="o">=</span> <span class="n">filenames</span><span class="p">[:</span><span class="n">split</span><span class="p">]</span>
<span class="n">train_data</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">ArrayDataset</span><span class="p">(</span><span class="n">nd</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">train_images</span><span class="p">))</span>
<span class="n">train_dataloader</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">train_data</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">last_batch</span><span class="o">=</span><span class="s1">&#39;rollover&#39;</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="n">multiprocessing</span><span class="o">.</span><span class="n">cpu_count</span><span class="p">()</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
</pre></div>
</div>
<div class="section" id="Generator">
<h2>Generator<a class="headerlink" href="#Generator" title="Permalink to this headline"></a></h2>
<p>Define the Generator model. Architecture is taken from the DCGAN implementation in <a class="reference external" href="https://gluon.mxnet.io/chapter14_generative-adversarial-networks/dcgan.html">Straight Dope Book</a>. The Generator consist of 4 layers where each layer involves a strided convolution, batch normalization, and rectified nonlinearity. It takes as input random noise and the latent code and produces an <code class="docutils literal notranslate"><span class="pre">(64,64,3)</span></code> output image.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Generator</span><span class="p">(</span><span class="n">gluon</span><span class="o">.</span><span class="n">HybridBlock</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">Generator</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">name_scope</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">prev</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">HybridSequential</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">prev</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1024</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(),</span> <span class="n">nn</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">HybridSequential</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">64</span> <span class="o">*</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s1">&#39;relu&#39;</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">64</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s1">&#39;relu&#39;</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">64</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s1">&#39;relu&#39;</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s1">&#39;relu&#39;</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s1">&#39;tanh&#39;</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">hybrid_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">F</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prev</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">G</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="Discriminator">
<h2>Discriminator<a class="headerlink" href="#Discriminator" title="Permalink to this headline"></a></h2>
<p>Define the Discriminator and Q model. The Q model shares many layers with the Discriminator. Its task is to estimate the code <code class="docutils literal notranslate"><span class="pre">c</span></code> for a given fake image. It is used to maximize the lower bound to the mutual information.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Discriminator</span><span class="p">(</span><span class="n">gluon</span><span class="o">.</span><span class="n">HybridBlock</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">Discriminator</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">name_scope</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">D</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">HybridSequential</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="mf">0.2</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="mf">0.2</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="mf">0.2</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span> <span class="o">*</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="mf">0.2</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1024</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(),</span> <span class="n">nn</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">feat</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">HybridSequential</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">feat</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(),</span> <span class="n">nn</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">category_prob</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">n_categories</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">continuous_mean</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">n_continuous</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">Q</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">HybridSequential</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">Q</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">feat</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">category_prob</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">continuous_mean</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">hybrid_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">F</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">prob</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">feat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feat</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">category_prob</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">category_prob</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">continuous_mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">continuous_mean</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="k">return</span> <span class="n">prob</span><span class="p">,</span> <span class="n">category_prob</span><span class="p">,</span> <span class="n">continuous_mean</span>
</pre></div>
</div>
<p>The InfoGAN has the following layout.</p>
<p>Discriminator and Generator are the same as in the DCGAN example. On top of the Disciminator is the Q model, which is estimating the code <code class="docutils literal notranslate"><span class="pre">c</span></code> for given fake images. The Generator’s input is random noise and the latent code <code class="docutils literal notranslate"><span class="pre">c</span></code>.</p>
</div>
<div class="section" id="Training-Loop">
<h2>Training Loop<a class="headerlink" href="#Training-Loop" title="Permalink to this headline"></a></h2>
<p>Initialize Generator and Discriminator and define correspoing trainer function.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">generator</span> <span class="o">=</span> <span class="n">Generator</span><span class="p">()</span>
<span class="n">generator</span><span class="o">.</span><span class="n">hybridize</span><span class="p">()</span>
<span class="n">generator</span><span class="o">.</span><span class="n">initialize</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="mf">0.002</span><span class="p">),</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="n">discriminator</span> <span class="o">=</span> <span class="n">Discriminator</span><span class="p">()</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">hybridize</span><span class="p">()</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">initialize</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="mf">0.002</span><span class="p">),</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="n">lr</span> <span class="o">=</span> <span class="mf">0.0001</span>
<span class="n">beta</span> <span class="o">=</span> <span class="mf">0.5</span>
<span class="n">g_trainer</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">Trainer</span><span class="p">(</span><span class="n">generator</span><span class="o">.</span><span class="n">collect_params</span><span class="p">(),</span> <span class="s1">&#39;adam&#39;</span><span class="p">,</span> <span class="p">{</span><span class="s1">&#39;learning_rate&#39;</span><span class="p">:</span> <span class="n">lr</span><span class="p">,</span> <span class="s1">&#39;beta1&#39;</span><span class="p">:</span> <span class="n">beta</span><span class="p">})</span>
<span class="n">d_trainer</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">Trainer</span><span class="p">(</span><span class="n">discriminator</span><span class="o">.</span><span class="n">collect_params</span><span class="p">(),</span> <span class="s1">&#39;adam&#39;</span><span class="p">,</span> <span class="p">{</span><span class="s1">&#39;learning_rate&#39;</span><span class="p">:</span> <span class="n">lr</span><span class="p">,</span> <span class="s1">&#39;beta1&#39;</span><span class="p">:</span> <span class="n">beta</span><span class="p">})</span>
<span class="n">q_trainer</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">Trainer</span><span class="p">(</span><span class="n">discriminator</span><span class="o">.</span><span class="n">Q</span><span class="o">.</span><span class="n">collect_params</span><span class="p">(),</span> <span class="s1">&#39;adam&#39;</span><span class="p">,</span> <span class="p">{</span><span class="s1">&#39;learning_rate&#39;</span><span class="p">:</span> <span class="n">lr</span><span class="p">,</span> <span class="s1">&#39;beta1&#39;</span><span class="p">:</span> <span class="n">beta</span><span class="p">})</span>
</pre></div>
</div>
<p>Create vectors with real (=1) and fake labels (=0).</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">real_label</span> <span class="o">=</span> <span class="n">nd</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,),</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="n">fake_label</span> <span class="o">=</span> <span class="n">nd</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,),</span><span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
</pre></div>
</div>
<p>Load a pretrained model.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isfile</span><span class="p">(</span><span class="s1">&#39;infogan_d_latest.params&#39;</span><span class="p">)</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isfile</span><span class="p">(</span><span class="s1">&#39;infogan_g_latest.params&#39;</span><span class="p">):</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">load_parameters</span><span class="p">(</span><span class="s1">&#39;infogan_d_latest.params&#39;</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">allow_missing</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">ignore_extra</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">generator</span><span class="o">.</span><span class="n">load_parameters</span><span class="p">(</span><span class="s1">&#39;infogan_g_latest.params&#39;</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">allow_missing</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">ignore_extra</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
<p>There are 2 differences between InfoGAN and DCGAN: the extra latent code and the Q network to estimate the code. The latent code is part of the Generator input and it contains mutliple variables (continuous, categorical) that can represent different distributions. In order to make sure that the Generator uses the latent code, mutual information is introduced into the GAN loss term. Mutual information measures how much X is known given Y or vice versa. It is defined as:</p>
<img alt="gif" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/info_gan/entropy.gif" />
<p>The InfoGAN loss is:</p>
<img alt="gif" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/info_gan/loss.gif" />
<p>where <code class="docutils literal notranslate"><span class="pre">V(D,G)</span></code> is the GAN loss and the mutual information <code class="docutils literal notranslate"><span class="pre">I(c,</span> <span class="pre">G(z,</span> <span class="pre">c))</span></code> goes in as regularization. The goal is to reach high mutual information, in order to learn meaningful codes for the data.</p>
<p>Define the loss functions. <code class="docutils literal notranslate"><span class="pre">SoftmaxCrossEntropyLoss</span></code> for the categorical code, <code class="docutils literal notranslate"><span class="pre">L2Loss</span></code> for the continious code and <code class="docutils literal notranslate"><span class="pre">SigmoidBinaryCrossEntropyLoss</span></code> for the normal GAN loss.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">loss1</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">loss</span><span class="o">.</span><span class="n">SigmoidBinaryCrossEntropyLoss</span><span class="p">()</span>
<span class="n">loss2</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">loss</span><span class="o">.</span><span class="n">L2Loss</span><span class="p">()</span>
<span class="n">loss3</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">loss</span><span class="o">.</span><span class="n">SoftmaxCrossEntropyLoss</span><span class="p">()</span>
</pre></div>
</div>
<p>This function samples <code class="docutils literal notranslate"><span class="pre">c</span></code>, <code class="docutils literal notranslate"><span class="pre">z</span></code>, and concatenates them to create the generator input.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">create_generator_input</span><span class="p">():</span>
<span class="c1">#create random noise</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">nd</span><span class="o">.</span><span class="n">random_normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">z_dim</span><span class="p">),</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="n">label</span> <span class="o">=</span> <span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="n">n_categories</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">))</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
<span class="n">c1</span> <span class="o">=</span> <span class="n">nd</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">depth</span><span class="o">=</span><span class="n">n_categories</span><span class="p">)</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
<span class="n">c2</span> <span class="o">=</span> <span class="n">nd</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_continuous</span><span class="p">))</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
<span class="c1"># concatenate random noise with c which will be the input of the generator</span>
<span class="k">return</span> <span class="n">nd</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">c1</span><span class="p">,</span> <span class="n">c2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span> <span class="n">label</span><span class="p">,</span> <span class="n">c2</span>
</pre></div>
</div>
<p>Define the training loop. 1. The discriminator receives <code class="docutils literal notranslate"><span class="pre">real_data</span></code> and <code class="docutils literal notranslate"><span class="pre">loss1</span></code> measures how many real images have been identified as real 2. The discriminator receives <code class="docutils literal notranslate"><span class="pre">fake_image</span></code> from the Generator and <code class="docutils literal notranslate"><span class="pre">loss1</span></code> measures how many fake images have been identified as fake 3. Update Discriminator. Currently, it is updated every second iteration in order to avoid that the Discriminator becomes too strong. You may want to change that. 4. The updated discriminator receives <code class="docutils literal notranslate"><span class="pre">fake_image</span></code> and
<code class="docutils literal notranslate"><span class="pre">loss1</span></code> measures how many fake images have been been identified as real, <code class="docutils literal notranslate"><span class="pre">loss2</span></code> measures the difference between the sampled continuous latent code <code class="docutils literal notranslate"><span class="pre">c</span></code> and the output of the Q model and <code class="docutils literal notranslate"><span class="pre">loss3</span></code> measures the difference between the sampled categorical latent code <code class="docutils literal notranslate"><span class="pre">c</span></code> and the output of the Q model. 4. Update Generator and Q</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">with</span> <span class="n">SummaryWriter</span><span class="p">(</span><span class="n">logdir</span><span class="o">=</span><span class="s1">&#39;./logs/&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">sw</span><span class="p">:</span>
<span class="n">epochs</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">counter</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Epoch&quot;</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span>
<span class="n">starttime</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">d_error_epoch</span> <span class="o">=</span> <span class="n">nd</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,),</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="n">g_error_epoch</span> <span class="o">=</span> <span class="n">nd</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,),</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">data</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_dataloader</span><span class="p">):</span>
<span class="c1">#get real data and generator input</span>
<span class="n">real_data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
<span class="n">g_input</span><span class="p">,</span> <span class="n">label</span><span class="p">,</span> <span class="n">c2</span> <span class="o">=</span> <span class="n">create_generator_input</span><span class="p">()</span>
<span class="c1">#Update discriminator: Input real data and fake data</span>
<span class="k">with</span> <span class="n">autograd</span><span class="o">.</span><span class="n">record</span><span class="p">():</span>
<span class="n">output_real</span><span class="p">,</span><span class="n">_</span><span class="p">,</span><span class="n">_</span> <span class="o">=</span> <span class="n">discriminator</span><span class="p">(</span><span class="n">real_data</span><span class="p">)</span>
<span class="n">d_error_real</span> <span class="o">=</span> <span class="n">loss1</span><span class="p">(</span><span class="n">output_real</span><span class="p">,</span> <span class="n">real_label</span><span class="p">)</span>
<span class="c1"># create fake image and input it to discriminator</span>
<span class="n">fake_image</span> <span class="o">=</span> <span class="n">generator</span><span class="p">(</span><span class="n">g_input</span><span class="p">)</span>
<span class="n">output_fake</span><span class="p">,</span><span class="n">_</span><span class="p">,</span><span class="n">_</span> <span class="o">=</span> <span class="n">discriminator</span><span class="p">(</span><span class="n">fake_image</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
<span class="n">d_error_fake</span> <span class="o">=</span> <span class="n">loss1</span><span class="p">(</span><span class="n">output_fake</span><span class="p">,</span> <span class="n">fake_label</span><span class="p">)</span>
<span class="c1"># total discriminator error</span>
<span class="n">d_error</span> <span class="o">=</span> <span class="n">d_error_real</span> <span class="o">+</span> <span class="n">d_error_fake</span>
<span class="n">d_error_epoch</span> <span class="o">+=</span> <span class="n">d_error</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="c1">#Update D every second iteration</span>
<span class="k">if</span> <span class="p">(</span><span class="n">counter</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">d_error</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">d_trainer</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="c1">#Update generator: Input random noise and latent code vector</span>
<span class="k">with</span> <span class="n">autograd</span><span class="o">.</span><span class="n">record</span><span class="p">():</span>
<span class="n">fake_image</span> <span class="o">=</span> <span class="n">generator</span><span class="p">(</span><span class="n">g_input</span><span class="p">)</span>
<span class="n">output_fake</span><span class="p">,</span> <span class="n">category_prob</span><span class="p">,</span> <span class="n">continuous_mean</span> <span class="o">=</span> <span class="n">discriminator</span><span class="p">(</span><span class="n">fake_image</span><span class="p">)</span>
<span class="n">g_error</span> <span class="o">=</span> <span class="n">loss1</span><span class="p">(</span><span class="n">output_fake</span><span class="p">,</span> <span class="n">real_label</span><span class="p">)</span> <span class="o">+</span> <span class="n">loss3</span><span class="p">(</span><span class="n">category_prob</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span> <span class="o">+</span> <span class="n">loss2</span><span class="p">(</span><span class="n">c2</span><span class="p">,</span> <span class="n">continuous_mean</span><span class="p">)</span>
<span class="n">g_error</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">g_error_epoch</span> <span class="o">+=</span> <span class="n">g_error</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="n">g_trainer</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="n">q_trainer</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="c1"># logging</span>
<span class="k">if</span> <span class="n">idx</span> <span class="o">%</span> <span class="mi">10</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">count</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;speed: </span><span class="si">{}</span><span class="s1"> samples/s&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">/</span> <span class="p">(</span><span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">starttime</span><span class="p">)))</span>
<span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;discriminator loss = </span><span class="si">%f</span><span class="s1">, generator loss = </span><span class="si">%f</span><span class="s1"> at iter </span><span class="si">%d</span><span class="s1"> epoch </span><span class="si">%d</span><span class="s1">&#39;</span>
<span class="o">%</span><span class="p">(</span><span class="n">d_error_epoch</span><span class="o">.</span><span class="n">asscalar</span><span class="p">()</span><span class="o">/</span><span class="n">count</span><span class="p">,</span><span class="n">g_error_epoch</span><span class="o">.</span><span class="n">asscalar</span><span class="p">()</span><span class="o">/</span><span class="n">count</span><span class="p">,</span> <span class="n">count</span><span class="p">,</span> <span class="n">epoch</span><span class="p">))</span>
<span class="n">g_input</span><span class="p">,</span><span class="n">_</span><span class="p">,</span><span class="n">_</span> <span class="o">=</span> <span class="n">create_generator_input</span><span class="p">()</span>
<span class="c1"># create some fake image for logging in MXBoard</span>
<span class="n">fake_image</span> <span class="o">=</span> <span class="n">generator</span><span class="p">(</span><span class="n">g_input</span><span class="p">)</span>
<span class="n">sw</span><span class="o">.</span><span class="n">add_scalar</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="s1">&#39;Loss_D&#39;</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;test&#39;</span><span class="p">:</span><span class="n">d_error_epoch</span><span class="o">.</span><span class="n">asscalar</span><span class="p">()</span><span class="o">/</span><span class="n">count</span><span class="p">},</span> <span class="n">global_step</span><span class="o">=</span><span class="n">counter</span><span class="p">)</span>
<span class="n">sw</span><span class="o">.</span><span class="n">add_scalar</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="s1">&#39;Loss_G&#39;</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;test&#39;</span><span class="p">:</span><span class="n">d_error_epoch</span><span class="o">.</span><span class="n">asscalar</span><span class="p">()</span><span class="o">/</span><span class="n">count</span><span class="p">},</span> <span class="n">global_step</span><span class="o">=</span><span class="n">counter</span><span class="p">)</span>
<span class="n">sw</span><span class="o">.</span><span class="n">add_image</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="s1">&#39;data_image&#39;</span><span class="p">,</span> <span class="n">image</span><span class="o">=</span><span class="p">((</span><span class="n">fake_image</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">+</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">*</span> <span class="mf">127.5</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="p">,</span> <span class="n">global_step</span><span class="o">=</span><span class="n">counter</span><span class="p">)</span>
<span class="n">sw</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">save_parameters</span><span class="p">(</span><span class="s2">&quot;infogan_d_latest.params&quot;</span><span class="p">)</span>
<span class="n">generator</span><span class="o">.</span><span class="n">save_parameters</span><span class="p">(</span><span class="s2">&quot;infogan_g_latest.params&quot;</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="Image-similarity">
<h2>Image similarity<a class="headerlink" href="#Image-similarity" title="Permalink to this headline"></a></h2>
<p>Once the InfoGAN is trained, we can use the Discriminator to do an image similarity search. The idea is that the network learned meaningful features from the images based on the mutual information e.g. pose of people in an image.</p>
<p>Load the trained discriminator and retrieve one of its last layers.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">discriminator</span> <span class="o">=</span> <span class="n">Discriminator</span><span class="p">()</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">load_parameters</span><span class="p">(</span><span class="s2">&quot;infogan_d_latest.params&quot;</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">,</span> <span class="n">ignore_extra</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">discriminator</span> <span class="o">=</span> <span class="n">discriminator</span><span class="o">.</span><span class="n">D</span><span class="p">[:</span><span class="mi">11</span><span class="p">]</span>
<span class="nb">print</span> <span class="p">(</span><span class="n">discriminator</span><span class="p">)</span>
<span class="n">discriminator</span><span class="o">.</span><span class="n">hybridize</span><span class="p">()</span>
</pre></div>
</div>
<p>Nearest neighbor function, which takes a matrix of features and an input feature vector. It returns the 3 closest features.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">get_knn</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">input_vector</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">3</span><span class="p">):</span>
<span class="n">dist</span> <span class="o">=</span> <span class="p">(</span><span class="n">nd</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">features</span> <span class="o">-</span> <span class="n">input_vector</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span><span class="o">/</span><span class="n">features</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">dist</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span><span class="o">.</span><span class="n">argsort</span><span class="p">()[:</span><span class="n">k</span><span class="p">]</span>
<span class="k">return</span> <span class="p">[(</span><span class="n">index</span><span class="p">,</span> <span class="n">dist</span><span class="p">[</span><span class="n">index</span><span class="p">]</span><span class="o">.</span><span class="n">asscalar</span><span class="p">())</span> <span class="k">for</span> <span class="n">index</span> <span class="ow">in</span> <span class="n">indices</span><span class="p">]</span>
</pre></div>
</div>
<p>A helper function to visualize image data.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">visualize</span><span class="p">(</span><span class="n">img_array</span><span class="p">):</span>
<span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(((</span><span class="n">img_array</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">*</span> <span class="mf">127.5</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">))</span>
<span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span>
</pre></div>
</div>
<p>Take some images from the test data, obtain its feature vector from <code class="docutils literal notranslate"><span class="pre">discriminator.D[:11]</span></code> and plot images of the corresponding closest vectors in the feature space.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">feature_size</span> <span class="o">=</span> <span class="mi">8192</span>
<span class="n">features</span> <span class="o">=</span> <span class="n">nd</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">test_images</span><span class="p">),</span> <span class="n">feature_size</span><span class="p">),</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">image</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">test_images</span><span class="p">):</span>
<span class="n">feature</span> <span class="o">=</span> <span class="n">discriminator</span><span class="p">(</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">))</span>
<span class="n">feature</span> <span class="o">=</span> <span class="n">feature</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">feature_size</span><span class="p">,)</span>
<span class="n">features</span><span class="p">[</span><span class="n">idx</span><span class="p">,:]</span> <span class="o">=</span> <span class="n">feature</span><span class="o">.</span><span class="n">copyto</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
<span class="k">for</span> <span class="n">image</span> <span class="ow">in</span> <span class="n">test_images</span><span class="p">[:</span><span class="mi">100</span><span class="p">]:</span>
<span class="n">feature</span> <span class="o">=</span> <span class="n">discriminator</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">))</span>
<span class="n">feature</span> <span class="o">=</span> <span class="n">feature</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">feature_size</span><span class="p">,))</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span><span class="mi">64</span><span class="p">,</span><span class="mi">64</span><span class="p">))</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">get_knn</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">feature</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span><span class="mi">12</span><span class="p">))</span>
<span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">10</span><span class="p">,</span><span class="mi">1</span><span class="p">)</span>
<span class="n">visualize</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">9</span><span class="p">):</span>
<span class="k">if</span> <span class="n">indices</span><span class="p">[</span><span class="n">i</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;</span> <span class="mf">1.5</span><span class="p">:</span>
<span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">10</span><span class="p">,</span><span class="n">i</span><span class="p">)</span>
<span class="n">sim</span> <span class="o">=</span> <span class="n">test_images</span><span class="p">[</span><span class="n">indices</span><span class="p">[</span><span class="n">i</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="mi">0</span><span class="p">]]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="mi">64</span><span class="p">,</span><span class="mi">64</span><span class="p">)</span>
<span class="n">visualize</span><span class="p">(</span><span class="n">sim</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
<span class="n">plt</span><span class="o">.</span><span class="n">clf</span><span class="p">()</span>
</pre></div>
</div>
<p><img alt="png" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/info_gan/output.png" /></p>
</div>
<div class="section" id="How-the-Generator-learns">
<h2>How the Generator learns<a class="headerlink" href="#How-the-Generator-learns" title="Permalink to this headline"></a></h2>
<p>We trained the Generator for a couple of epochs and stored a couple of fake images per epoch. Check the video. <img alt="alt text" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/info_gan/infogan.gif" /></p>
<p>The following function computes the TSNE on the feature matrix and stores the result in a json-file. This file can be loaded with <a class="reference external" href="https://ml4a.github.io/guides/ImageTSNEViewer/">TSNEViewer</a></p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">json</span>
<span class="kn">from</span> <span class="nn">sklearn.manifold</span> <span class="kn">import</span> <span class="n">TSNE</span>
<span class="kn">from</span> <span class="nn">scipy.spatial</span> <span class="kn">import</span> <span class="n">distance</span>
<span class="n">tsne</span> <span class="o">=</span> <span class="n">TSNE</span><span class="p">(</span><span class="n">n_components</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=</span><span class="mi">150</span><span class="p">,</span> <span class="n">perplexity</span><span class="o">=</span><span class="mi">30</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">fit_transform</span><span class="p">(</span><span class="n">features</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">())</span>
<span class="c1"># save data to json</span>
<span class="n">data</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">counter</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span><span class="n">f</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">test_filenames</span><span class="p">):</span>
<span class="n">point</span> <span class="o">=</span> <span class="p">[</span><span class="nb">float</span><span class="p">((</span><span class="n">tsne</span><span class="p">[</span><span class="n">i</span><span class="p">,</span><span class="n">k</span><span class="p">]</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">tsne</span><span class="p">[:,</span><span class="n">k</span><span class="p">]))</span><span class="o">/</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">tsne</span><span class="p">[:,</span><span class="n">k</span><span class="p">])</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">tsne</span><span class="p">[:,</span><span class="n">k</span><span class="p">])))</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="p">]</span>
<span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">({</span><span class="s2">&quot;path&quot;</span><span class="p">:</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">getcwd</span><span class="p">(),</span><span class="n">f</span><span class="p">)),</span> <span class="s2">&quot;point&quot;</span><span class="p">:</span> <span class="n">point</span><span class="p">})</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="s2">&quot;imagetsne.json&quot;</span><span class="p">,</span> <span class="s1">&#39;w&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">outfile</span><span class="p">:</span>
<span class="n">json</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">outfile</span><span class="p">)</span>
</pre></div>
</div>
<p>Load the file with TSNEViewer. You can now inspect whether similiar looking images are grouped nearby or not.</p>
<!-- INSERT SOURCE DOWNLOAD BUTTONS --></div>
</div>
<hr class="feedback-hr-top" />
<div class="feedback-container">
<div class="feedback-question">Did this page help you?</div>
<div class="feedback-answer-container">
<div class="feedback-answer yes-link" data-response="yes">Yes</div>
<div class="feedback-answer no-link" data-response="no">No</div>
</div>
<div class="feedback-thank-you">Thanks for your feedback!</div>
</div>
<hr class="feedback-hr-bottom" />
</div>
<div class="side-doc-outline">
<div class="side-doc-outline--content">
<div class="localtoc">
<p class="caption">
<span class="caption-text">Table Of Contents</span>
</p>
<ul>
<li><a class="reference internal" href="#">Image similarity search with InfoGAN</a><ul>
<li><a class="reference internal" href="#Generator">Generator</a></li>
<li><a class="reference internal" href="#Discriminator">Discriminator</a></li>
<li><a class="reference internal" href="#Training-Loop">Training Loop</a></li>
<li><a class="reference internal" href="#Image-similarity">Image similarity</a></li>
<li><a class="reference internal" href="#How-the-Generator-learns">How the Generator learns</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div>
<div class="clearer"></div>
</div><div class="pagenation">
<a id="button-prev" href="image-augmentation.html" class="mdl-button mdl-js-button mdl-js-ripple-effect mdl-button--colored" role="botton" accesskey="P">
<i class="pagenation-arrow-L fas fa-arrow-left fa-lg"></i>
<div class="pagenation-text">
<span class="pagenation-direction">Previous</span>
<div>Image Augmentation</div>
</div>
</a>
<a id="button-next" href="mnist.html" class="mdl-button mdl-js-button mdl-js-ripple-effect mdl-button--colored" role="botton" accesskey="N">
<i class="pagenation-arrow-R fas fa-arrow-right fa-lg"></i>
<div class="pagenation-text">
<span class="pagenation-direction">Next</span>
<div>Handwritten Digit Recognition</div>
</div>
</a>
</div>
<footer class="site-footer h-card">
<div class="wrapper">
<div class="row">
<div class="col-4">
<h4 class="footer-category-title">Resources</h4>
<ul class="contact-list">
<li><a class="u-email" href="mailto:dev@mxnet.apache.org">Dev list</a></li>
<li><a class="u-email" href="mailto:user@mxnet.apache.org">User mailing list</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
<li><a href="https://issues.apache.org/jira/projects/MXNET/issues">Jira Tracker</a></li>
<li><a href="https://github.com/apache/mxnet/labels/Roadmap">Github Roadmap</a></li>
<li><a href="https://medium.com/apache-mxnet">Blog</a></li>
<li><a href="https://discuss.mxnet.io">Forum</a></li>
<li><a href="/community/contribute">Contribute</a></li>
</ul>
</div>
<div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/mxnet"><svg class="svg-icon"><use xlink:href="../../../../_static/minima-social-icons.svg#github"></use></svg> <span class="username">apache/mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="../../../../_static/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="../../../../_static/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul>
</div>
<div class="col-4 footer-text">
<p>A flexible and efficient library for deep learning.</p>
</div>
</div>
</div>
</footer>
<footer class="site-footer2">
<div class="wrapper">
<div class="row">
<div class="col-3">
<img src="../../../../_static/apache_incubator_logo.png" class="footer-logo col-2">
</div>
<div class="footer-bottom-warning col-9">
<p>Apache MXNet is an effort undergoing incubation at <a href="http://www.apache.org/">The Apache Software Foundation</a> (ASF), <span style="font-weight:bold">sponsored by the <i>Apache Incubator</i></span>. Incubation is required
of all newly accepted projects until a further review indicates that the infrastructure,
communications, and decision making process have stabilized in a manner consistent with other
successful ASF projects. While incubation status is not necessarily a reflection of the completeness
or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
</p><p>"Copyright © 2017-2018, The Apache Software Foundation Apache MXNet, MXNet, Apache, the Apache
feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the
Apache Software Foundation."</p>
</div>
</div>
</div>
</footer>
</body>
</html>