<!DOCTYPE html>

<html xmlns="http://www.w3.org/1999/xhtml">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
    <meta http-equiv="x-ua-compatible" content="ie=edge">
    <style>
    .dropdown {
        position: relative;
        display: inline-block;
    }

    .dropdown-content {
        display: none;
        position: absolute;
        background-color: #f9f9f9;
        min-width: 160px;
        box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
        padding: 12px 16px;
        z-index: 1;
        text-align: left;
    }

    .dropdown:hover .dropdown-content {
        display: block;
    }

    .dropdown-option:hover {
        color: #FF4500;
    }

    .dropdown-option-active {
        color: #FF4500;
        font-weight: lighter;
    }

    .dropdown-option {
        color: #000000;
        font-weight: lighter;
    }

    .dropdown-header {
        color: #FFFFFF;
        display: inline-flex;
    }

    .dropdown-caret {
        width: 18px;
    }

    .dropdown-caret-path {
        fill: #FFFFFF;
    }
    </style>
    
    <title>Google Neural Machine Translation &#8212; Apache MXNet documentation</title>

    <link rel="stylesheet" href="../../../../_static/basic.css" type="text/css" />
    <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
    <link rel="stylesheet" type="text/css" href="../../../../_static/mxnet.css" />
    <link rel="stylesheet" href="../../../../_static/material-design-lite-1.3.0/material.blue-deep_orange.min.css" type="text/css" />
    <link rel="stylesheet" href="../../../../_static/sphinx_materialdesign_theme.css" type="text/css" />
    <link rel="stylesheet" href="../../../../_static/fontawesome/all.css" type="text/css" />
    <link rel="stylesheet" href="../../../../_static/fonts.css" type="text/css" />
    <link rel="stylesheet" href="../../../../_static/feedback.css" type="text/css" />
    <script id="documentation_options" data-url_root="../../../../" src="../../../../_static/documentation_options.js"></script>
    <script src="../../../../_static/jquery.js"></script>
    <script src="../../../../_static/underscore.js"></script>
    <script src="../../../../_static/doctools.js"></script>
    <script src="../../../../_static/language_data.js"></script>
    <script src="../../../../_static/google_analytics.js"></script>
    <script src="../../../../_static/autodoc.js"></script>
    <script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
    <script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
    <script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
    <link rel="shortcut icon" href="../../../../_static/mxnet-icon.png"/>
    <link rel="index" title="Index" href="../../../../genindex.html" />
    <link rel="search" title="Search" href="../../../../search.html" />
    <link rel="next" title="Machine Translation with Transformer" href="transformer.html" />
    <link rel="prev" title="Text Tutorials" href="index.html" /> 
  </head>
<body><header class="site-header" role="banner">
  <div class="wrapper">
      <a class="site-title" rel="author" href="/versions/1.7/"><img
            src="../../../../_static/mxnet_logo.png" class="site-header-logo"></a>
    <nav class="site-nav">
      <input type="checkbox" id="nav-trigger" class="nav-trigger"/>
      <label for="nav-trigger">
          <span class="menu-icon">
            <svg viewBox="0 0 18 15" width="18px" height="15px">
              <path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
            </svg>
          </span>
      </label>

      <div class="trigger">
        <a class="page-link" href="/versions/1.7/get_started">Get Started</a>
        <a class="page-link" href="/versions/1.7/blog">Blog</a>
        <a class="page-link" href="/versions/1.7/features">Features</a>
        <a class="page-link" href="/versions/1.7/ecosystem">Ecosystem</a>
        <a class="page-link page-current" href="/versions/1.7/api">Docs &amp; Tutorials</a>
        <a class="page-link" href="https://github.com/apache/incubator-mxnet">GitHub</a>
        <div class="dropdown">
          <span class="dropdown-header">1.7
            <svg class="dropdown-caret" viewBox="0 0 32 32" aria-hidden="true"><path class="dropdown-caret-path" d="M24 11.305l-7.997 11.39L8 11.305z"></path></svg>
          </span>
          <div class="dropdown-content">
            <a class="dropdown-option" href="/">master</a><br>
            <a class="dropdown-option-active" href="/versions/1.7/">1.7</a><br>
            <a class="dropdown-option" href="/versions/1.6/">1.6</a><br>
            <a class="dropdown-option" href="/versions/1.5.0/">1.5.0</a><br>
            <a class="dropdown-option" href="/versions/1.4.1/">1.4.1</a><br>
            <a class="dropdown-option" href="/versions/1.3.1/">1.3.1</a><br>
            <a class="dropdown-option" href="/versions/1.2.1/">1.2.1</a><br>
            <a class="dropdown-option" href="/versions/1.1.0/">1.1.0</a><br>
            <a class="dropdown-option" href="/versions/1.0.0/">1.0.0</a><br>
            <a class="dropdown-option" href="/versions/0.12.1/">0.12.1</a><br>
            <a class="dropdown-option" href="/versions/0.11.0/">0.11.0</a>
          </div>
        </div>
      </div>
    </nav>
  </div>
</header>
    <div class="mdl-layout mdl-js-layout mdl-layout--fixed-header mdl-layout--fixed-drawer"><header class="mdl-layout__header mdl-layout__header--waterfall ">
    <div class="mdl-layout__header-row">
        
        <nav class="mdl-navigation breadcrumb">
            <a class="mdl-navigation__link" href="../../../index.html">Python Tutorials</a><i class="material-icons">navigate_next</i>
            <a class="mdl-navigation__link" href="../../index.html">Packages</a><i class="material-icons">navigate_next</i>
            <a class="mdl-navigation__link" href="../index.html">Gluon</a><i class="material-icons">navigate_next</i>
            <a class="mdl-navigation__link" href="index.html">Text Tutorials</a><i class="material-icons">navigate_next</i>
            <a class="mdl-navigation__link is-active">Google Neural Machine Translation</a>
        </nav>
        <div class="mdl-layout-spacer"></div>
        <nav class="mdl-navigation">
        
<form class="form-inline pull-sm-right" action="../../../../search.html" method="get">
      <div class="mdl-textfield mdl-js-textfield mdl-textfield--expandable mdl-textfield--floating-label mdl-textfield--align-right">
        <label id="quick-search-icon" class="mdl-button mdl-js-button mdl-button--icon"  for="waterfall-exp">
          <i class="material-icons">search</i>
        </label>
        <div class="mdl-textfield__expandable-holder">
          <input class="mdl-textfield__input" type="text" name="q"  id="waterfall-exp" placeholder="Search" />
          <input type="hidden" name="check_keywords" value="yes" />
          <input type="hidden" name="area" value="default" />
        </div>
      </div>
      <div class="mdl-tooltip" data-mdl-for="quick-search-icon">
      Quick search
      </div>
</form>
        
<a id="button-show-source"
    class="mdl-button mdl-js-button mdl-button--icon"
    href="../../../../_sources/tutorials/packages/gluon/text/gnmt.rst" rel="nofollow">
  <i class="material-icons">code</i>
</a>
<div class="mdl-tooltip" data-mdl-for="button-show-source">
Show Source
</div>
        </nav>
    </div>
    <div class="mdl-layout__header-row header-links">
      <div class="mdl-layout-spacer"></div>
      <nav class="mdl-navigation">
      </nav>
    </div>
</header><header class="mdl-layout__drawer">      
    
      <div class="globaltoc">
        <span class="mdl-layout-title toc">Table Of Contents</span>
        
        
            
            <nav class="mdl-navigation">
                <ul class="current">
<li class="toctree-l1 current"><a class="reference internal" href="../../../index.html">Python Tutorials</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="../../../getting-started/index.html">Getting Started</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../getting-started/crash-course/index.html">Crash Course</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/1-ndarray.html">Manipulate data with <code class="docutils literal notranslate"><span class="pre">ndarray</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/2-nn.html">Create a neural network</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/3-autograd.html">Automatic differentiation with <code class="docutils literal notranslate"><span class="pre">autograd</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/4-train.html">Train the neural network</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/5-predict.html">Predict with a pre-trained model</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/crash-course/6-use_gpus.html">Use GPUs</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../getting-started/to-mxnet/index.html">Moving to MXNet from Other Frameworks</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../getting-started/to-mxnet/pytorch.html">PyTorch vs Apache MXNet</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../getting-started/gluon_from_experiment_to_deployment.html">Gluon: from experiment to deployment</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../getting-started/logistic_regression_explained.html">Logistic regression explained</a></li>
<li class="toctree-l3"><a class="reference external" href="https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html">MNIST</a></li>
</ul>
</li>
<li class="toctree-l2 current"><a class="reference internal" href="../../index.html">Packages</a><ul class="current">
<li class="toctree-l3"><a class="reference internal" href="../../autograd/index.html">Automatic Differentiation</a></li>
<li class="toctree-l3 current"><a class="reference internal" href="../index.html">Gluon</a><ul class="current">
<li class="toctree-l4"><a class="reference internal" href="../blocks/index.html">Blocks</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../blocks/custom-layer.html">Custom Layers</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/custom_layer_beginners.html">Custom Layers (Beginners)</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/hybridize.html">Hybridize</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/init.html">Initialization</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/naming.html">Parameter and Block Naming</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/nn.html">Layers and Blocks</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/parameters.html">Parameter Management</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/save_load_params.html">Saving and Loading Gluon Models</a></li>
<li class="toctree-l5"><a class="reference internal" href="../blocks/activations/activations.html">Activation Blocks</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../data/index.html">Data Tutorials</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../data/data_augmentation.html">Image Augmentation</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/data_augmentation.html#Spatial-Augmentation">Spatial Augmentation</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/data_augmentation.html#Color-Augmentation">Color Augmentation</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/data_augmentation.html#Composed-Augmentations">Composed Augmentations</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/datasets.html">Gluon <code class="docutils literal notranslate"><span class="pre">Dataset</span></code>s and <code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/datasets.html#Using-own-data-with-included-Datasets">Using own data with included <code class="docutils literal notranslate"><span class="pre">Dataset</span></code>s</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/datasets.html#Using-own-data-with-custom-Datasets">Using own data with custom <code class="docutils literal notranslate"><span class="pre">Dataset</span></code>s</a></li>
<li class="toctree-l5"><a class="reference internal" href="../data/datasets.html#Appendix:-Upgrading-from-Module-DataIter-to-Gluon-DataLoader">Appendix: Upgrading from Module <code class="docutils literal notranslate"><span class="pre">DataIter</span></code> to Gluon <code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../image/index.html">Image Tutorials</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../image/image-augmentation.html">Image Augmentation</a></li>
<li class="toctree-l5"><a class="reference internal" href="../image/info_gan.html">Image similarity search with InfoGAN</a></li>
<li class="toctree-l5"><a class="reference internal" href="../image/mnist.html">Handwritten Digit Recognition</a></li>
<li class="toctree-l5"><a class="reference internal" href="../image/pretrained_models.html">Using pre-trained models in MXNet</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../loss/index.html">Losses</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../loss/custom-loss.html">Custom Loss Blocks</a></li>
<li class="toctree-l5"><a class="reference internal" href="../loss/kl_divergence.html">Kullback-Leibler (KL) Divergence</a></li>
<li class="toctree-l5"><a class="reference internal" href="../loss/loss.html">Loss functions</a></li>
</ul>
</li>
<li class="toctree-l4 current"><a class="reference internal" href="index.html">Text Tutorials</a><ul class="current">
<li class="toctree-l5 current"><a class="current reference internal" href="#">Google Neural Machine Translation</a></li>
<li class="toctree-l5"><a class="reference internal" href="transformer.html">Machine Translation with Transformer</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../training/index.html">Training</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../training/fit_api_tutorial.html">MXNet Gluon Fit API</a></li>
<li class="toctree-l5"><a class="reference internal" href="../training/trainer.html">Trainer</a></li>
<li class="toctree-l5"><a class="reference internal" href="../training/learning_rates/index.html">Learning Rates</a><ul>
<li class="toctree-l6"><a class="reference internal" href="../training/learning_rates/learning_rate_finder.html">Learning Rate Finder</a></li>
<li class="toctree-l6"><a class="reference internal" href="../training/learning_rates/learning_rate_schedules.html">Learning Rate Schedules</a></li>
<li class="toctree-l6"><a class="reference internal" href="../training/learning_rates/learning_rate_schedules_advanced.html">Advanced Learning Rate Schedules</a></li>
</ul>
</li>
<li class="toctree-l5"><a class="reference internal" href="../training/normalization/index.html">Normalization Blocks</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../kvstore/index.html">KVStore</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../kvstore/kvstore.html">Distributed Key-Value Store</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../ndarray/index.html">NDArray</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/01-ndarray-intro.html">An Intro: Manipulate Data the MXNet Way with NDArray</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/02-ndarray-operations.html">NDArray Operations</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/03-ndarray-contexts.html">NDArray Contexts</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/gotchas_numpy_in_mxnet.html">Gotchas using NumPy in Apache MXNet</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../ndarray/sparse/index.html">Tutorials</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../../ndarray/sparse/csr.html">CSRNDArray - NDArray in Compressed Sparse Row Storage Format</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../ndarray/sparse/row_sparse.html">RowSparseNDArray - NDArray for Sparse Gradient Updates</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../ndarray/sparse/train.html">Train a Linear Regression Model with Sparse Symbols</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../ndarray/sparse/train_gluon.html">Sparse NDArrays with Gluon</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../onnx/index.html">ONNX</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../onnx/fine_tuning_gluon.html">Fine-tuning an ONNX model</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../onnx/inference_on_onnx_model.html">Running inference on MXNet/Gluon from an ONNX model</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../onnx/super_resolution.html">Importing an ONNX model into MXNet</a></li>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/python/docs/tutorials/deploy/export/onnx.html">Export ONNX Models</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../optimizer/index.html">Optimizers</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../viz/index.html">Visualization</a><ul>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/faq/visualize_graph">Visualize networks</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../performance/index.html">Performance</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../performance/compression/index.html">Compression</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/compression/int8.html">Deploy with int-8</a></li>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/faq/float16">Float16</a></li>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/faq/gradient_compression">Gradient Compression</a></li>
<li class="toctree-l4"><a class="reference external" href="https://gluon-cv.mxnet.io/build/examples_deployment/int8_inference.html">GluonCV with Quantized Models</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../performance/backend/index.html">Accelerated Backend Tools</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/mkldnn/index.html">Intel MKL-DNN</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../../../performance/backend/mkldnn/mkldnn_quantization.html">Quantize with MKL-DNN backend</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../../performance/backend/mkldnn/mkldnn_readme.html">Install MXNet with MKL-DNN</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/tensorrt/index.html">TensorRT</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../../../performance/backend/tensorrt/tensorrt.html">Optimizing Deep Learning Computation Graphs with TensorRT</a></li>
</ul>
</li>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/tvm.html">Use TVM</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/profiler.html">Profiling MXNet Models</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../performance/backend/amp.html">Using AMP: Automatic Mixed Precision</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../deploy/index.html">Deployment</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../deploy/export/index.html">Export</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/export/onnx.html">Exporting to ONNX format</a></li>
<li class="toctree-l4"><a class="reference external" href="https://gluon-cv.mxnet.io/build/examples_deployment/export_network.html">Export Gluon CV Models</a></li>
<li class="toctree-l4"><a class="reference external" href="https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html">Save / Load Parameters</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../deploy/inference/index.html">Inference</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/inference/cpp.html">Deploy into C++</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/inference/image_classification_jetson.html">Image Classification using pretrained ResNet-50 model on Jetson module</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/inference/scala.html">Deploy into a Java or Scala Environment</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/inference/wine_detector.html">Real-time Object Detection with MXNet On The Raspberry Pi</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../deploy/run-on-aws/index.html">Run on AWS</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/run-on-aws/use_ec2.html">Run on an EC2 Instance</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/run-on-aws/use_sagemaker.html">Run on Amazon SageMaker</a></li>
<li class="toctree-l4"><a class="reference internal" href="../../../deploy/run-on-aws/cloud.html">MXNet on the Cloud</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../extend/index.html">Extend</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../extend/custom_layer.html">Custom Layers</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../extend/customop.html">Custom Numpy Operators</a></li>
<li class="toctree-l3"><a class="reference external" href="https://mxnet.apache.org/api/faq/new_op">New Operator Creation</a></li>
<li class="toctree-l3"><a class="reference external" href="https://mxnet.apache.org/api/faq/add_op_in_backend">New Operator in MXNet Backend</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../../../api/index.html">Python API</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/ndarray/index.html">mxnet.ndarray</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/ndarray.html">ndarray</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/contrib/index.html">ndarray.contrib</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/image/index.html">ndarray.image</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/linalg/index.html">ndarray.linalg</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/op/index.html">ndarray.op</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/random/index.html">ndarray.random</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/register/index.html">ndarray.register</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/sparse/index.html">ndarray.sparse</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/ndarray/utils/index.html">ndarray.utils</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/gluon/index.html">mxnet.gluon</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/block.html">gluon.Block</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/hybrid_block.html">gluon.HybridBlock</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/symbol_block.html">gluon.SymbolBlock</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/constant.html">gluon.Constant</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/parameter.html">gluon.Parameter</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/parameter_dict.html">gluon.ParameterDict</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/trainer.html">gluon.Trainer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/contrib/index.html">gluon.contrib</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/data/index.html">gluon.data</a><ul>
<li class="toctree-l4"><a class="reference internal" href="../../../../api/gluon/data/vision/index.html">data.vision</a><ul>
<li class="toctree-l5"><a class="reference internal" href="../../../../api/gluon/data/vision/datasets/index.html">vision.datasets</a></li>
<li class="toctree-l5"><a class="reference internal" href="../../../../api/gluon/data/vision/transforms/index.html">vision.transforms</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/loss/index.html">gluon.loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/model_zoo/index.html">gluon.model_zoo.vision</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/nn/index.html">gluon.nn</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/rnn/index.html">gluon.rnn</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/gluon/utils/index.html">gluon.utils</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/autograd/index.html">mxnet.autograd</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/initializer/index.html">mxnet.initializer</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/optimizer/index.html">mxnet.optimizer</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/lr_scheduler/index.html">mxnet.lr_scheduler</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/metric/index.html">mxnet.metric</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/kvstore/index.html">mxnet.kvstore</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/symbol/index.html">mxnet.symbol</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/symbol.html">symbol</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/contrib/index.html">symbol.contrib</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/image/index.html">symbol.image</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/linalg/index.html">symbol.linalg</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/op/index.html">symbol.op</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/random/index.html">symbol.random</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/register/index.html">symbol.register</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/symbol/sparse/index.html">symbol.sparse</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/module/index.html">mxnet.module</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/contrib/index.html">mxnet.contrib</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/autograd/index.html">contrib.autograd</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/io/index.html">contrib.io</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/ndarray/index.html">contrib.ndarray</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/onnx/index.html">contrib.onnx</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/quantization/index.html">contrib.quantization</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/symbol/index.html">contrib.symbol</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/tensorboard/index.html">contrib.tensorboard</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/tensorrt/index.html">contrib.tensorrt</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/contrib/text/index.html">contrib.text</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../../../api/mxnet/index.html">mxnet</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/attribute/index.html">mxnet.attribute</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/base/index.html">mxnet.base</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/callback/index.html">mxnet.callback</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/context/index.html">mxnet.context</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/engine/index.html">mxnet.engine</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/executor/index.html">mxnet.executor</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/executor_manager/index.html">mxnet.executor_manager</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/image/index.html">mxnet.image</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/io/index.html">mxnet.io</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/kvstore_server/index.html">mxnet.kvstore_server</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/libinfo/index.html">mxnet.libinfo</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/log/index.html">mxnet.log</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/model/index.html">mxnet.model</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/monitor/index.html">mxnet.monitor</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/name/index.html">mxnet.name</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/notebook/index.html">mxnet.notebook</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/operator/index.html">mxnet.operator</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/profiler/index.html">mxnet.profiler</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/random/index.html">mxnet.random</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/recordio/index.html">mxnet.recordio</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/registry/index.html">mxnet.registry</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/rtc/index.html">mxnet.rtc</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/runtime/index.html">mxnet.runtime</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/test_utils/index.html">mxnet.test_utils</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/torch/index.html">mxnet.torch</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/util/index.html">mxnet.util</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../../../api/mxnet/visualization/index.html">mxnet.visualization</a></li>
</ul>
</li>
</ul>
</li>
</ul>

            </nav>
        
        </div>
    
</header>
        <main class="mdl-layout__content" tabIndex="0">

        <script type="text/javascript" src="../../../../_static/sphinx_materialdesign_theme.js "></script>
        <script type="text/javascript" src="../../../../_static/feedback.js"></script>

    <div class="document">
        <div class="page-content" role="main">
        
  <div class="section" id="google-neural-machine-translation">
<h1>Google Neural Machine Translation<a class="headerlink" href="#google-neural-machine-translation" title="Permalink to this headline">¶</a></h1>
<p>In this notebook, we train a Google Neural Machine Translation (GNMT)
model on the IWSLT 2015 English-Vietnamese dataset. The process has four
steps: 1) load and preprocess the dataset, 2) create the sampler and
DataLoader, 3) build the model, and 4) write the training loop.</p>
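<p>Step 2 is easier to follow with a concrete picture of what a length-bucketing sampler might do. The sketch below is plain Python and is not the sampler used in this tutorial: <code class="docutils literal notranslate"><span class="pre">bucket_batches</span></code> and its arguments are hypothetical names, and it assumes the sampler's job is to batch sentences of similar length so that little padding is needed.</p>

```python
import random
from collections import defaultdict


def bucket_batches(lengths, num_buckets, batch_size, seed=0):
    """Group sample indices into batches of similar sequence length.

    Hypothetical helper for illustration only; it is not part of
    MXNet or GluonNLP.
    """
    lo, hi = min(lengths), max(lengths)
    # Width of each length range (ceiling division); at least 1 so that
    # inputs of identical length still work.
    width = max(1, -(-(hi - lo + 1) // num_buckets))
    buckets = defaultdict(list)
    for idx, n in enumerate(lengths):
        buckets[(n - lo) // width].append(idx)

    batches = []
    for key in sorted(buckets):
        members = buckets[key]
        for start in range(0, len(members), batch_size):
            batches.append(members[start:start + batch_size])

    # Shuffle the order of the batches, not their contents, so each batch
    # still holds sentences of similar length.
    random.Random(seed).shuffle(batches)
    return batches


lengths = [3, 4, 48, 50, 5, 47, 25, 26]
batches = bucket_batches(lengths, num_buckets=5, batch_size=2)
for batch in batches:
    print([lengths[i] for i in batch])
```

<p>The batch order is shuffled, but every batch stays within one length range, so padding each batch to its longest sentence wastes few positions.</p>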
<div class="section" id="load-mxnet-and-gluon">
<h2>Load MXNet and Gluon<a class="headerlink" href="#load-mxnet-and-gluon" title="Permalink to this headline">¶</a></h2>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">warnings</span>
<span class="n">warnings</span><span class="o">.</span><span class="n">filterwarnings</span><span class="p">(</span><span class="s1">&#39;ignore&#39;</span><span class="p">)</span>

<span class="kn">import</span> <span class="nn">argparse</span>
<span class="kn">import</span> <span class="nn">time</span>
<span class="kn">import</span> <span class="nn">random</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">logging</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">mxnet</span> <span class="k">as</span> <span class="nn">mx</span>
<span class="kn">from</span> <span class="nn">mxnet</span> <span class="kn">import</span> <span class="n">gluon</span>
<span class="kn">import</span> <span class="nn">gluonnlp</span> <span class="k">as</span> <span class="nn">nlp</span>
<span class="kn">import</span> <span class="nn">nmt</span>
</pre></div>
</div>
</div>
<div class="section" id="hyper-parameters">
<h2>Hyper-parameters<a class="headerlink" href="#hyper-parameters" title="Permalink to this headline">¶</a></h2>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">100</span><span class="p">)</span>
<span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">100</span><span class="p">)</span>
<span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">10000</span><span class="p">)</span>
<span class="n">ctx</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">gpu</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>  <span class="c1"># requires a GPU; use mx.cpu() to run on CPU</span>

<span class="c1"># parameters for dataset</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="s1">&#39;IWSLT2015&#39;</span>
<span class="n">src_lang</span><span class="p">,</span> <span class="n">tgt_lang</span> <span class="o">=</span> <span class="s1">&#39;en&#39;</span><span class="p">,</span> <span class="s1">&#39;vi&#39;</span>
<span class="n">src_max_len</span><span class="p">,</span> <span class="n">tgt_max_len</span> <span class="o">=</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">50</span>

<span class="c1"># parameters for model</span>
<span class="n">num_hidden</span> <span class="o">=</span> <span class="mi">512</span>
<span class="n">num_layers</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">num_bi_layers</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">dropout</span> <span class="o">=</span> <span class="mf">0.2</span>

<span class="c1"># parameters for training</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">test_batch_size</span> <span class="o">=</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">32</span>
<span class="n">num_buckets</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">epochs</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">clip</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">lr</span> <span class="o">=</span> <span class="mf">0.001</span>
<span class="n">lr_update_factor</span> <span class="o">=</span> <span class="mf">0.5</span>
<span class="n">log_interval</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">save_dir</span> <span class="o">=</span> <span class="s1">&#39;gnmt_en_vi_u512&#39;</span>

<span class="c1"># parameters for testing</span>
<span class="n">beam_size</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">lp_alpha</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="n">lp_k</span> <span class="o">=</span> <span class="mi">5</span>

<span class="n">nmt</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">logging_config</span><span class="p">(</span><span class="n">save_dir</span><span class="p">)</span>
</pre></div>
</div>
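<p>The tutorial does not spell out what <code class="docutils literal notranslate"><span class="pre">lp_alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">lp_k</span></code> control. Assuming they parameterize the length penalty described in the GNMT paper, where a beam-search hypothesis score is divided by a penalty that grows with hypothesis length, a minimal sketch looks like the following. The function names here are illustrative and are not part of the <code class="docutils literal notranslate"><span class="pre">nmt</span></code> module.</p>

```python
def length_penalty(length, alpha=1.0, k=5):
    """GNMT-style length penalty: ((k + length) / (k + 1)) ** alpha.

    Assumption: `alpha` and `k` play the roles of lp_alpha and lp_k above.
    """
    return ((k + length) / (k + 1)) ** alpha


def rescore(log_prob, length, alpha=1.0, k=5):
    """Divide a hypothesis log-probability by its length penalty."""
    return log_prob / length_penalty(length, alpha, k)


# A one-token hypothesis is not penalized; longer ones are scaled down in magnitude.
short_score = rescore(-6.0, length=1)
long_score = rescore(-6.0, length=7)
print(short_score, long_score)
```

<p>With <code class="docutils literal notranslate"><span class="pre">alpha</span> <span class="pre">=</span> <span class="pre">0</span></code> the penalty is always 1, so scores are left unchanged and shorter hypotheses tend to win; larger values favor longer outputs.</p>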
</div>
<div class="section" id="load-and-preprocess-dataset">
<h2>Load and Preprocess Dataset<a class="headerlink" href="#load-and-preprocess-dataset" title="Permalink to this headline">¶</a></h2>
<p>The following code processes the dataset and caches the result for
future use. The processing steps are: 1) clip the source and target
sequences to a maximum length, 2) split each input string into a list of
tokens, 3) map each token to its integer index in the vocabulary, and 4)
append an end-of-sentence (EOS) token to the source sentence and add
beginning-of-sentence (BOS) and EOS tokens to the target sentence.</p>
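<p>To make the four steps concrete, here is a small sketch that runs them on a toy sentence. It uses a plain dictionary as the vocabulary; the helper names and the special-token strings are made up for this example and stand in for the GluonNLP <code class="docutils literal notranslate"><span class="pre">Vocab</span></code> object used in the tutorial.</p>

```python
def build_vocab(tokens, specials=('<unk>', '<bos>', '<eos>')):
    """Assign an integer index to each special token and each distinct token."""
    vocab = {}
    for tok in list(specials) + sorted(set(tokens)):
        vocab.setdefault(tok, len(vocab))
    return vocab


def preprocess(sentence, vocab, max_len=-1, add_bos=False):
    tokens = sentence.split()                  # split the string into tokens
    if max_len > 0:
        tokens = tokens[:max_len]              # clip to at most max_len tokens
    ids = [vocab.get(t, vocab['<unk>']) for t in tokens]  # token -> index
    if add_bos:
        ids.insert(0, vocab['<bos>'])          # BOS is added to targets only
    ids.append(vocab['<eos>'])                 # EOS is added to both sides
    return ids


vocab = build_vocab('the cat sat on the mat'.split())
src_ids = preprocess('the cat sat on the mat', vocab, max_len=3)
tgt_ids = preprocess('the cat sat', vocab, add_bos=True)
print(src_ids)  # [7, 3, 6, 2]
print(tgt_ids)  # [1, 7, 3, 6, 2]
```

<p>Clipping happens before the special tokens are added, so a clipped source sequence can contain up to <code class="docutils literal notranslate"><span class="pre">max_len</span> <span class="pre">+</span> <span class="pre">1</span></code> indices.</p>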
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">cache_dataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">prefix</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Cache the processed npy dataset into an npz file.</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    dataset : gluon.data.SimpleDataset</span>
<span class="sd">    prefix : str</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">nmt</span><span class="o">.</span><span class="n">_constants</span><span class="o">.</span><span class="n">CACHE_PATH</span><span class="p">):</span>
        <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">nmt</span><span class="o">.</span><span class="n">_constants</span><span class="o">.</span><span class="n">CACHE_PATH</span><span class="p">)</span>
    <span class="n">src_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">ele</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">ele</span> <span class="ow">in</span> <span class="n">dataset</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">object</span><span class="p">)</span>
    <span class="n">tgt_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">ele</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">ele</span> <span class="ow">in</span> <span class="n">dataset</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">object</span><span class="p">)</span>
    <span class="n">np</span><span class="o">.</span><span class="n">savez</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">nmt</span><span class="o">.</span><span class="n">_constants</span><span class="o">.</span><span class="n">CACHE_PATH</span><span class="p">,</span> <span class="n">prefix</span> <span class="o">+</span> <span class="s1">&#39;.npz&#39;</span><span class="p">),</span> <span class="n">src_data</span><span class="o">=</span><span class="n">src_data</span><span class="p">,</span> <span class="n">tgt_data</span><span class="o">=</span><span class="n">tgt_data</span><span class="p">)</span>


<span class="k">def</span> <span class="nf">load_cached_dataset</span><span class="p">(</span><span class="n">prefix</span><span class="p">):</span>
    <span class="n">cached_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">nmt</span><span class="o">.</span><span class="n">_constants</span><span class="o">.</span><span class="n">CACHE_PATH</span><span class="p">,</span> <span class="n">prefix</span> <span class="o">+</span> <span class="s1">&#39;.npz&#39;</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">cached_file_path</span><span class="p">):</span>
        <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Load cached data from </span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">cached_file_path</span><span class="p">))</span>
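        <span class="n">dat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">cached_file_path</span><span class="p">,</span> <span class="n">allow_pickle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>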
        <span class="k">return</span> <span class="n">gluon</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">ArrayDataset</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dat</span><span class="p">[</span><span class="s1">&#39;src_data&#39;</span><span class="p">]),</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dat</span><span class="p">[</span><span class="s1">&#39;tgt_data&#39;</span><span class="p">]))</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="k">return</span> <span class="kc">None</span>


<span class="k">class</span> <span class="nc">TrainValDataTransform</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Transform the machine translation dataset.</span>

<span class="sd">    Clip source and the target sentences to the maximum length. For the source sentence, append the</span>
<span class="sd">    EOS. For the target sentence, append BOS and EOS.</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    src_vocab : Vocab</span>
<span class="sd">    tgt_vocab : Vocab</span>
<span class="sd">    src_max_len : int</span>
<span class="sd">    tgt_max_len : int</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src_vocab</span><span class="p">,</span> <span class="n">tgt_vocab</span><span class="p">,</span> <span class="n">src_max_len</span><span class="p">,</span> <span class="n">tgt_max_len</span><span class="p">):</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_src_vocab</span> <span class="o">=</span> <span class="n">src_vocab</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_tgt_vocab</span> <span class="o">=</span> <span class="n">tgt_vocab</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_src_max_len</span> <span class="o">=</span> <span class="n">src_max_len</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_tgt_max_len</span> <span class="o">=</span> <span class="n">tgt_max_len</span>

    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">):</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_src_max_len</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">src_sentence</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_src_vocab</span><span class="p">[</span><span class="n">src</span><span class="o">.</span><span class="n">split</span><span class="p">()[:</span><span class="bp">self</span><span class="o">.</span><span class="n">_src_max_len</span><span class="p">]]</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">src_sentence</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_src_vocab</span><span class="p">[</span><span class="n">src</span><span class="o">.</span><span class="n">split</span><span class="p">()]</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tgt_max_len</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">tgt_sentence</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tgt_vocab</span><span class="p">[</span><span class="n">tgt</span><span class="o">.</span><span class="n">split</span><span class="p">()[:</span><span class="bp">self</span><span class="o">.</span><span class="n">_tgt_max_len</span><span class="p">]]</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">tgt_sentence</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tgt_vocab</span><span class="p">[</span><span class="n">tgt</span><span class="o">.</span><span class="n">split</span><span class="p">()]</span>
        <span class="n">src_sentence</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_src_vocab</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_src_vocab</span><span class="o">.</span><span class="n">eos_token</span><span class="p">])</span>
        <span class="n">tgt_sentence</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tgt_vocab</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_tgt_vocab</span><span class="o">.</span><span class="n">bos_token</span><span class="p">])</span>
        <span class="n">tgt_sentence</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_tgt_vocab</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_tgt_vocab</span><span class="o">.</span><span class="n">eos_token</span><span class="p">])</span>
        <span class="n">src_npy</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">src_sentence</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
        <span class="n">tgt_npy</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tgt_sentence</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">src_npy</span><span class="p">,</span> <span class="n">tgt_npy</span>


<span class="k">def</span> <span class="nf">process_dataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">src_vocab</span><span class="p">,</span> <span class="n">tgt_vocab</span><span class="p">,</span> <span class="n">src_max_len</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">tgt_max_len</span><span class="o">=-</span><span class="mi">1</span><span class="p">):</span>
    <span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
    <span class="n">dataset_processed</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">TrainValDataTransform</span><span class="p">(</span><span class="n">src_vocab</span><span class="p">,</span> <span class="n">tgt_vocab</span><span class="p">,</span>
                                                                <span class="n">src_max_len</span><span class="p">,</span>
                                                                <span class="n">tgt_max_len</span><span class="p">),</span> <span class="n">lazy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    <span class="n">end</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
    <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Processing time spent: </span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">end</span> <span class="o">-</span> <span class="n">start</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">dataset_processed</span>


<span class="k">def</span> <span class="nf">load_translation_data</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">src_lang</span><span class="o">=</span><span class="s1">&#39;en&#39;</span><span class="p">,</span> <span class="n">tgt_lang</span><span class="o">=</span><span class="s1">&#39;vi&#39;</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Load translation dataset</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    dataset : str</span>
<span class="sd">    src_lang : str, default &#39;en&#39;</span>
<span class="sd">    tgt_lang : str, default &#39;vi&#39;</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    data_train_processed : Dataset</span>
<span class="sd">        The preprocessed training sentence pairs</span>
<span class="sd">    data_val_processed : Dataset</span>
<span class="sd">        The preprocessed validation sentence pairs</span>
<span class="sd">    data_test_processed : Dataset</span>
<span class="sd">        The preprocessed test sentence pairs</span>
<span class="sd">    val_tgt_sentences : list</span>
<span class="sd">        The target sentences in the validation set</span>
<span class="sd">    test_tgt_sentences : list</span>
<span class="sd">        The target sentences in the test set</span>
<span class="sd">    src_vocab : Vocab</span>
<span class="sd">        Vocabulary of the source language</span>
<span class="sd">    tgt_vocab : Vocab</span>
<span class="sd">        Vocabulary of the target language</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">common_prefix</span> <span class="o">=</span> <span class="s1">&#39;IWSLT2015_</span><span class="si">{}</span><span class="s1">_</span><span class="si">{}</span><span class="s1">_</span><span class="si">{}</span><span class="s1">_</span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">src_lang</span><span class="p">,</span> <span class="n">tgt_lang</span><span class="p">,</span>
                                                   <span class="n">src_max_len</span><span class="p">,</span> <span class="n">tgt_max_len</span><span class="p">)</span>
    <span class="n">data_train</span> <span class="o">=</span> <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">IWSLT2015</span><span class="p">(</span><span class="s1">&#39;train&#39;</span><span class="p">,</span> <span class="n">src_lang</span><span class="o">=</span><span class="n">src_lang</span><span class="p">,</span> <span class="n">tgt_lang</span><span class="o">=</span><span class="n">tgt_lang</span><span class="p">)</span>
    <span class="n">data_val</span> <span class="o">=</span> <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">IWSLT2015</span><span class="p">(</span><span class="s1">&#39;val&#39;</span><span class="p">,</span> <span class="n">src_lang</span><span class="o">=</span><span class="n">src_lang</span><span class="p">,</span> <span class="n">tgt_lang</span><span class="o">=</span><span class="n">tgt_lang</span><span class="p">)</span>
    <span class="n">data_test</span> <span class="o">=</span> <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">IWSLT2015</span><span class="p">(</span><span class="s1">&#39;test&#39;</span><span class="p">,</span> <span class="n">src_lang</span><span class="o">=</span><span class="n">src_lang</span><span class="p">,</span> <span class="n">tgt_lang</span><span class="o">=</span><span class="n">tgt_lang</span><span class="p">)</span>
    <span class="n">src_vocab</span><span class="p">,</span> <span class="n">tgt_vocab</span> <span class="o">=</span> <span class="n">data_train</span><span class="o">.</span><span class="n">src_vocab</span><span class="p">,</span> <span class="n">data_train</span><span class="o">.</span><span class="n">tgt_vocab</span>
    <span class="n">data_train_processed</span> <span class="o">=</span> <span class="n">load_cached_dataset</span><span class="p">(</span><span class="n">common_prefix</span> <span class="o">+</span> <span class="s1">&#39;_train&#39;</span><span class="p">)</span>
    <span class="k">if</span> <span class="ow">not</span> <span class="n">data_train_processed</span><span class="p">:</span>
        <span class="n">data_train_processed</span> <span class="o">=</span> <span class="n">process_dataset</span><span class="p">(</span><span class="n">data_train</span><span class="p">,</span> <span class="n">src_vocab</span><span class="p">,</span> <span class="n">tgt_vocab</span><span class="p">,</span>
                                               <span class="n">src_max_len</span><span class="p">,</span> <span class="n">tgt_max_len</span><span class="p">)</span>
        <span class="n">cache_dataset</span><span class="p">(</span><span class="n">data_train_processed</span><span class="p">,</span> <span class="n">common_prefix</span> <span class="o">+</span> <span class="s1">&#39;_train&#39;</span><span class="p">)</span>
    <span class="n">data_val_processed</span> <span class="o">=</span> <span class="n">load_cached_dataset</span><span class="p">(</span><span class="n">common_prefix</span> <span class="o">+</span> <span class="s1">&#39;_val&#39;</span><span class="p">)</span>
    <span class="k">if</span> <span class="ow">not</span> <span class="n">data_val_processed</span><span class="p">:</span>
        <span class="n">data_val_processed</span> <span class="o">=</span> <span class="n">process_dataset</span><span class="p">(</span><span class="n">data_val</span><span class="p">,</span> <span class="n">src_vocab</span><span class="p">,</span> <span class="n">tgt_vocab</span><span class="p">)</span>
        <span class="n">cache_dataset</span><span class="p">(</span><span class="n">data_val_processed</span><span class="p">,</span> <span class="n">common_prefix</span> <span class="o">+</span> <span class="s1">&#39;_val&#39;</span><span class="p">)</span>
    <span class="n">data_test_processed</span> <span class="o">=</span> <span class="n">load_cached_dataset</span><span class="p">(</span><span class="n">common_prefix</span> <span class="o">+</span> <span class="s1">&#39;_test&#39;</span><span class="p">)</span>
    <span class="k">if</span> <span class="ow">not</span> <span class="n">data_test_processed</span><span class="p">:</span>
        <span class="n">data_test_processed</span> <span class="o">=</span> <span class="n">process_dataset</span><span class="p">(</span><span class="n">data_test</span><span class="p">,</span> <span class="n">src_vocab</span><span class="p">,</span> <span class="n">tgt_vocab</span><span class="p">)</span>
        <span class="n">cache_dataset</span><span class="p">(</span><span class="n">data_test_processed</span><span class="p">,</span> <span class="n">common_prefix</span> <span class="o">+</span> <span class="s1">&#39;_test&#39;</span><span class="p">)</span>
    <span class="n">fetch_tgt_sentence</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="n">tgt</span><span class="o">.</span><span class="n">split</span><span class="p">()</span>
    <span class="n">val_tgt_sentences</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">data_val</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">fetch_tgt_sentence</span><span class="p">))</span>
    <span class="n">test_tgt_sentences</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">data_test</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">fetch_tgt_sentence</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">data_train_processed</span><span class="p">,</span> <span class="n">data_val_processed</span><span class="p">,</span> <span class="n">data_test_processed</span><span class="p">,</span> \
           <span class="n">val_tgt_sentences</span><span class="p">,</span> <span class="n">test_tgt_sentences</span><span class="p">,</span> <span class="n">src_vocab</span><span class="p">,</span> <span class="n">tgt_vocab</span>


<span class="k">def</span> <span class="nf">get_data_lengths</span><span class="p">(</span><span class="n">dataset</span><span class="p">):</span>
    <span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="k">lambda</span> <span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">tgt</span><span class="p">))))</span>


<span class="n">data_train</span><span class="p">,</span> <span class="n">data_val</span><span class="p">,</span> <span class="n">data_test</span><span class="p">,</span> <span class="n">val_tgt_sentences</span><span class="p">,</span> <span class="n">test_tgt_sentences</span><span class="p">,</span> <span class="n">src_vocab</span><span class="p">,</span> <span class="n">tgt_vocab</span>\
    <span class="o">=</span> <span class="n">load_translation_data</span><span class="p">(</span><span class="n">dataset</span><span class="o">=</span><span class="n">dataset</span><span class="p">,</span> <span class="n">src_lang</span><span class="o">=</span><span class="n">src_lang</span><span class="p">,</span> <span class="n">tgt_lang</span><span class="o">=</span><span class="n">tgt_lang</span><span class="p">)</span>
<span class="n">data_train_lengths</span> <span class="o">=</span> <span class="n">get_data_lengths</span><span class="p">(</span><span class="n">data_train</span><span class="p">)</span>
<span class="n">data_val_lengths</span> <span class="o">=</span> <span class="n">get_data_lengths</span><span class="p">(</span><span class="n">data_val</span><span class="p">)</span>
<span class="n">data_test_lengths</span> <span class="o">=</span> <span class="n">get_data_lengths</span><span class="p">(</span><span class="n">data_test</span><span class="p">)</span>

<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">save_dir</span><span class="p">,</span> <span class="s1">&#39;val_gt.txt&#39;</span><span class="p">),</span> <span class="s1">&#39;w&#39;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s1">&#39;utf-8&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">of</span><span class="p">:</span>
    <span class="k">for</span> <span class="n">ele</span> <span class="ow">in</span> <span class="n">val_tgt_sentences</span><span class="p">:</span>
        <span class="n">of</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39; &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">ele</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>

<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">save_dir</span><span class="p">,</span> <span class="s1">&#39;test_gt.txt&#39;</span><span class="p">),</span> <span class="s1">&#39;w&#39;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s1">&#39;utf-8&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">of</span><span class="p">:</span>
    <span class="k">for</span> <span class="n">ele</span> <span class="ow">in</span> <span class="n">test_tgt_sentences</span><span class="p">:</span>
        <span class="n">of</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39; &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">ele</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>


<span class="n">data_train</span> <span class="o">=</span> <span class="n">data_train</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="k">lambda</span> <span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">tgt</span><span class="p">)),</span> <span class="n">lazy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">data_val</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">SimpleDataset</span><span class="p">([(</span><span class="n">ele</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">ele</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="nb">len</span><span class="p">(</span><span class="n">ele</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="nb">len</span><span class="p">(</span><span class="n">ele</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="n">i</span><span class="p">)</span>
                                     <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">ele</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">data_val</span><span class="p">)])</span>
<span class="n">data_test</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">SimpleDataset</span><span class="p">([(</span><span class="n">ele</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">ele</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="nb">len</span><span class="p">(</span><span class="n">ele</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="nb">len</span><span class="p">(</span><span class="n">ele</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="n">i</span><span class="p">)</span>
                                      <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">ele</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">data_test</span><span class="p">)])</span>
</pre></div>
</div>
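<p>The validation and test samples above carry a fifth field, the sample's position <code class="docutils literal notranslate"><span class="pre">i</span></code> in the original dataset. A likely reason, stated here as an assumption, is that batches may later be grouped by length, so the index is needed to put the translations back in the order of the ground-truth files. The sketch below illustrates that idea with made-up samples and a stand-in for the model.</p>

```python
# Hypothetical samples shaped like the val/test entries:
# (src_ids, tgt_ids, src_len, tgt_len, original_index)
samples = [
    ([4, 5, 6, 2], [1, 7, 2], 4, 3, 0),
    ([8, 2], [1, 9, 9, 2], 2, 4, 1),
    ([3, 3, 2], [1, 5, 2], 3, 3, 2),
]

# Mimic length-based batching: samples are visited shortest source first.
visit_order = sorted(samples, key=lambda s: s[2])

# Stand-in for translation: pair each "output" with the sample's index.
outputs = [(s[4], list(reversed(s[0]))) for s in visit_order]

# Use the stored index to restore the original dataset order.
restored = [None] * len(samples)
for idx, out in outputs:
    restored[idx] = out

print([s[4] for s in visit_order])  # [1, 2, 0]
print(restored)
```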
</div>
<div class="section" id="create-sampler-and-dataloader">
<h2>Create Sampler and DataLoader<a class="headerlink" href="#create-sampler-and-dataloader" title="Permalink to this headline">¶</a></h2>
<p>We now have <code class="docutils literal notranslate"><span class="pre">data_train</span></code>, <code class="docutils literal notranslate"><span class="pre">data_val</span></code>, and <code class="docutils literal notranslate"><span class="pre">data_test</span></code>.
The next step is to construct the sampler and the DataLoader. We start
with the batchify function, which pads and stacks sequences to form
mini-batches.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">train_batchify_fn</span> <span class="o">=</span> <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">batchify</span><span class="o">.</span><span class="n">Tuple</span><span class="p">(</span><span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">batchify</span><span class="o">.</span><span class="n">Pad</span><span class="p">(),</span>
                                            <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">batchify</span><span class="o">.</span><span class="n">Pad</span><span class="p">(),</span>
                                            <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">batchify</span><span class="o">.</span><span class="n">Stack</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">),</span>
                                            <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">batchify</span><span class="o">.</span><span class="n">Stack</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">))</span>
<span class="n">test_batchify_fn</span> <span class="o">=</span> <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">batchify</span><span class="o">.</span><span class="n">Tuple</span><span class="p">(</span><span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">batchify</span><span class="o">.</span><span class="n">Pad</span><span class="p">(),</span>
                                           <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">batchify</span><span class="o">.</span><span class="n">Pad</span><span class="p">(),</span>
                                           <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">batchify</span><span class="o">.</span><span class="n">Stack</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">),</span>
                                           <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">batchify</span><span class="o">.</span><span class="n">Stack</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;float32&#39;</span><span class="p">),</span>
                                           <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">batchify</span><span class="o">.</span><span class="n">Stack</span><span class="p">())</span>
</pre></div>
</div>
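<p>To make the padding step concrete, here is a minimal pure-Python sketch of what a pad-and-stack batchify function does to four-field samples. The helper names <code class="docutils literal notranslate"><span class="pre">pad_batch</span></code> and <code class="docutils literal notranslate"><span class="pre">batchify</span></code> and the padding value of 0 are assumptions made for illustration; this is not the GluonNLP implementation.</p>

```python
def pad_batch(sequences, pad_val=0):
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [list(seq) + [pad_val] * (max_len - len(seq)) for seq in sequences]


def batchify(samples):
    """Combine (src, tgt, src_len, tgt_len) samples into one mini-batch."""
    src, tgt, src_len, tgt_len = zip(*samples)
    # Sequences are padded; the scalar lengths are simply collected.
    return pad_batch(src), pad_batch(tgt), list(src_len), list(tgt_len)


samples = [([4, 5, 6], [7, 8], 3, 2), ([9], [10, 11, 12], 1, 3)]
src, tgt, src_len, tgt_len = batchify(samples)
print(src)      # [[4, 5, 6], [9, 0, 0]]
print(tgt)      # [[7, 8, 0], [10, 11, 12]]
print(src_len)  # [3, 1]
```

<p>The valid lengths are kept next to the padded sequences so that later stages can tell real tokens from padding.</p>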
<p>We can then construct bucketing samplers, which generate batches by
grouping sequences of similar lengths so that little padding is needed.
For the training sampler we use an exponential-width bucketing scheme
whose step of 1.2 was determined empirically.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">bucket_scheme</span> <span class="o">=</span> <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">ExpWidthBucket</span><span class="p">(</span><span class="n">bucket_len_step</span><span class="o">=</span><span class="mf">1.2</span><span class="p">)</span>
<span class="n">train_batch_sampler</span> <span class="o">=</span> <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">FixedBucketSampler</span><span class="p">(</span><span class="n">lengths</span><span class="o">=</span><span class="n">data_train_lengths</span><span class="p">,</span>
                                                  <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
                                                  <span class="n">num_buckets</span><span class="o">=</span><span class="n">num_buckets</span><span class="p">,</span>
                                                  <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                                                  <span class="n">bucket_scheme</span><span class="o">=</span><span class="n">bucket_scheme</span><span class="p">)</span>
<span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;Train Batch Sampler:</span><span class="se">\n</span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">train_batch_sampler</span><span class="o">.</span><span class="n">stats</span><span class="p">()))</span>
<span class="n">val_batch_sampler</span> <span class="o">=</span> <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">FixedBucketSampler</span><span class="p">(</span><span class="n">lengths</span><span class="o">=</span><span class="n">data_val_lengths</span><span class="p">,</span>
                                                <span class="n">batch_size</span><span class="o">=</span><span class="n">test_batch_size</span><span class="p">,</span>
                                                <span class="n">num_buckets</span><span class="o">=</span><span class="n">num_buckets</span><span class="p">,</span>
                                                <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;Valid Batch Sampler:</span><span class="se">\n</span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">val_batch_sampler</span><span class="o">.</span><span class="n">stats</span><span class="p">()))</span>
<span class="n">test_batch_sampler</span> <span class="o">=</span> <span class="n">nlp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">FixedBucketSampler</span><span class="p">(</span><span class="n">lengths</span><span class="o">=</span><span class="n">data_test_lengths</span><span class="p">,</span>
                                                 <span class="n">batch_size</span><span class="o">=</span><span class="n">test_batch_size</span><span class="p">,</span>
                                                 <span class="n">num_buckets</span><span class="o">=</span><span class="n">num_buckets</span><span class="p">,</span>
                                                 <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;Test Batch Sampler:</span><span class="se">\n</span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">test_batch_sampler</span><span class="o">.</span><span class="n">stats</span><span class="p">()))</span>
</pre></div>
</div>
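<p>The idea behind bucketing with exponentially growing widths can be sketched in a few lines of plain Python. The functions <code class="docutils literal notranslate"><span class="pre">exp_width_keys</span></code> and <code class="docutils literal notranslate"><span class="pre">bucket_batches</span></code> below are hypothetical helpers written for this illustration. They mimic the general technique, not the exact boundaries or shuffling that the library samplers produce.</p>

```python
import bisect


def exp_width_keys(min_len, max_len, num_buckets, step=1.2):
    """Return bucket upper bounds whose widths grow geometrically by `step`."""
    # Choose the first width so that all widths sum to max_len - min_len.
    first_width = (max_len - min_len) * (step - 1) / (step ** num_buckets - 1)
    keys = [int(round(min_len + first_width * (step ** (i + 1) - 1) / (step - 1)))
            for i in range(num_buckets)]
    keys[-1] = max_len  # make sure the longest sequence fits
    return keys


def bucket_batches(lengths, keys, batch_size):
    """Group sample indices by bucket, then cut each bucket into batches."""
    buckets = [[] for _ in keys]
    for idx, length in enumerate(lengths):
        # bisect_left finds the first bucket whose upper bound covers `length`.
        buckets[bisect.bisect_left(keys, length)].append(idx)
    batches = []
    for bucket in buckets:
        for start in range(0, len(bucket), batch_size):
            batches.append(bucket[start:start + batch_size])
    return batches


lengths = [3, 48, 5, 20, 22, 4, 50, 19]
keys = exp_width_keys(min_len=1, max_len=50, num_buckets=5)
batches = bucket_batches(lengths, keys, batch_size=2)
print(keys)
print(batches)
```

<p>Because every batch is drawn from a single bucket, the sequences in it have similar lengths. Short sequences, which are usually more frequent, fall into narrow buckets, while the rarer long sequences share wider ones.</p>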
<p>Given the samplers, we can create the DataLoaders, which are iterable and yield one mini-batch at a time.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">train_data_loader</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">data_train</span><span class="p">,</span>
                                          <span class="n">batch_sampler</span><span class="o">=</span><span class="n">train_batch_sampler</span><span class="p">,</span>
                                          <span class="n">batchify_fn</span><span class="o">=</span><span class="n">train_batchify_fn</span><span class="p">,</span>
                                          <span class="n">num_workers</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>
<span class="n">val_data_loader</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">data_val</span><span class="p">,</span>
                                        <span class="n">batch_sampler</span><span class="o">=</span><span class="n">val_batch_sampler</span><span class="p">,</span>
                                        <span class="n">batchify_fn</span><span class="o">=</span><span class="n">test_batchify_fn</span><span class="p">,</span>
                                        <span class="n">num_workers</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>
<span class="n">test_data_loader</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">data_test</span><span class="p">,</span>
                                         <span class="n">batch_sampler</span><span class="o">=</span><span class="n">test_batch_sampler</span><span class="p">,</span>
                                         <span class="n">batchify_fn</span><span class="o">=</span><span class="n">test_batchify_fn</span><span class="p">,</span>
                                         <span class="n">num_workers</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="section" id="build-gnmt-model">
<h2>Build GNMT Model<a class="headerlink" href="#build-gnmt-model" title="Permalink to this headline">¶</a></h2>
<p>With the DataLoaders ready, we can build the model. The GNMT encoder and
decoder are constructed by calling the
<code class="docutils literal notranslate"><span class="pre">get_gnmt_encoder_decoder</span></code> function. We then pass the encoder and
decoder to <code class="docutils literal notranslate"><span class="pre">NMTModel</span></code> to construct the GNMT model. <code class="docutils literal notranslate"><span class="pre">model.hybridize</span></code>
allows computation to be done using the symbolic backend.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">encoder</span><span class="p">,</span> <span class="n">decoder</span> <span class="o">=</span> <span class="n">nmt</span><span class="o">.</span><span class="n">gnmt</span><span class="o">.</span><span class="n">get_gnmt_encoder_decoder</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">=</span><span class="n">num_hidden</span><span class="p">,</span>
                                                     <span class="n">dropout</span><span class="o">=</span><span class="n">dropout</span><span class="p">,</span>
                                                     <span class="n">num_layers</span><span class="o">=</span><span class="n">num_layers</span><span class="p">,</span>
                                                     <span class="n">num_bi_layers</span><span class="o">=</span><span class="n">num_bi_layers</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">nmt</span><span class="o">.</span><span class="n">translation</span><span class="o">.</span><span class="n">NMTModel</span><span class="p">(</span><span class="n">src_vocab</span><span class="o">=</span><span class="n">src_vocab</span><span class="p">,</span> <span class="n">tgt_vocab</span><span class="o">=</span><span class="n">tgt_vocab</span><span class="p">,</span> <span class="n">encoder</span><span class="o">=</span><span class="n">encoder</span><span class="p">,</span> <span class="n">decoder</span><span class="o">=</span><span class="n">decoder</span><span class="p">,</span>
                                 <span class="n">embed_size</span><span class="o">=</span><span class="n">num_hidden</span><span class="p">,</span> <span class="n">prefix</span><span class="o">=</span><span class="s1">&#39;gnmt_&#39;</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">initialize</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">Uniform</span><span class="p">(</span><span class="mf">0.1</span><span class="p">),</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
<span class="n">static_alloc</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">model</span><span class="o">.</span><span class="n">hybridize</span><span class="p">(</span><span class="n">static_alloc</span><span class="o">=</span><span class="n">static_alloc</span><span class="p">)</span>
<span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>

<span class="c1"># Due to the paddings, we need to mask out the losses corresponding to padding tokens.</span>
<span class="n">loss_function</span> <span class="o">=</span> <span class="n">nmt</span><span class="o">.</span><span class="n">loss</span><span class="o">.</span><span class="n">SoftmaxCEMaskedLoss</span><span class="p">()</span>
<span class="n">loss_function</span><span class="o">.</span><span class="n">hybridize</span><span class="p">(</span><span class="n">static_alloc</span><span class="o">=</span><span class="n">static_alloc</span><span class="p">)</span>
</pre></div>
</div>
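<p>The effect of masking can be illustrated without any framework. The sketch below computes the cross-entropy of one padded sequence and zeroes the contribution of the padded steps. The function names are hypothetical, and averaging over the padded length is an assumption of this sketch rather than a statement about how <code class="docutils literal notranslate"><span class="pre">SoftmaxCEMaskedLoss</span></code> is implemented.</p>

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over a list of scores."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def masked_ce(logits, labels, valid_length):
    """Cross-entropy of one padded sequence, ignoring steps past valid_length."""
    mask = [1.0] * valid_length + [0.0] * (len(labels) - valid_length)
    step_losses = [-log_softmax(scores)[label] * w
                   for scores, label, w in zip(logits, labels, mask)]
    # Average over the padded length; the caller can rescale to a per-token loss.
    return sum(step_losses) / len(labels)


logits = [[0.0, 0.0], [0.0, 0.0], [9.0, -9.0]]  # the last step is padding
labels = [0, 1, 1]
loss = masked_ce(logits, labels, valid_length=2)
per_token = loss * len(labels) / 2  # rescale by padded length / valid length
print(loss, per_token)
```

<p>Whatever the model predicts at the padded step, the loss stays the same. Multiplying by the padded length and dividing by the valid length turns the result into an average over real tokens only.</p>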
<p>We also build the beam search translator.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">translator</span> <span class="o">=</span> <span class="n">nmt</span><span class="o">.</span><span class="n">translation</span><span class="o">.</span><span class="n">BeamSearchTranslator</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">,</span> <span class="n">beam_size</span><span class="o">=</span><span class="n">beam_size</span><span class="p">,</span>
                                                  <span class="n">scorer</span><span class="o">=</span><span class="n">nlp</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">BeamSearchScorer</span><span class="p">(</span><span class="n">alpha</span><span class="o">=</span><span class="n">lp_alpha</span><span class="p">,</span>
                                                                                    <span class="n">K</span><span class="o">=</span><span class="n">lp_k</span><span class="p">),</span>
                                                  <span class="n">max_length</span><span class="o">=</span><span class="n">tgt_max_len</span> <span class="o">+</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;Use beam_size=</span><span class="si">{}</span><span class="s1">, alpha=</span><span class="si">{}</span><span class="s1">, K=</span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">beam_size</span><span class="p">,</span> <span class="n">lp_alpha</span><span class="p">,</span> <span class="n">lp_k</span><span class="p">))</span>
</pre></div>
</div>
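<p>The scorer arguments <code class="docutils literal notranslate"><span class="pre">alpha</span></code> and <code class="docutils literal notranslate"><span class="pre">K</span></code> control length normalization. As a rough sketch, assuming the length penalty ((K + length) / (K + 1)) ** alpha described in the GNMT paper, the normalized score of a hypothesis can be computed as follows. The function names and the example values are illustrative only.</p>

```python
def length_penalty(length, alpha=1.0, K=5.0):
    """Length penalty ((K + length) / (K + 1)) ** alpha."""
    return ((K + length) / (K + 1.0)) ** alpha


def normalized_score(log_prob, length, alpha=1.0, K=5.0):
    """Divide the summed log-probability by the length penalty."""
    return log_prob / length_penalty(length, alpha, K)


# A short hypothesis with a higher raw log-probability ...
short = normalized_score(log_prob=-4.0, length=4)
# ... and a longer hypothesis with a lower raw log-probability.
longer = normalized_score(log_prob=-5.0, length=9)
print(short, longer)
```

<p>With <code class="docutils literal notranslate"><span class="pre">alpha=1</span></code> and <code class="docutils literal notranslate"><span class="pre">K=5</span></code>, the longer hypothesis receives the better normalized score even though its raw log-probability is lower. Setting <code class="docutils literal notranslate"><span class="pre">alpha=0</span></code> disables the normalization, in which case beam search tends to favor shorter outputs.</p>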
<p>We define the evaluation function as follows. The <code class="docutils literal notranslate"><span class="pre">evaluate</span></code> function uses
the beam search translator to generate outputs for the validation and
testing datasets, and also returns the average loss.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">evaluate</span><span class="p">(</span><span class="n">data_loader</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Evaluate given the data loader</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    data_loader : gluon.data.DataLoader</span>

<span class="sd">    Returns</span>
<span class="sd">    -------</span>
<span class="sd">    avg_loss : float</span>
<span class="sd">        Average loss</span>
<span class="sd">    real_translation_out : list of list of str</span>
<span class="sd">        The translation output</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">translation_out</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">all_inst_ids</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">avg_loss_denom</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">avg_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
    <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="p">(</span><span class="n">src_seq</span><span class="p">,</span> <span class="n">tgt_seq</span><span class="p">,</span> <span class="n">src_valid_length</span><span class="p">,</span> <span class="n">tgt_valid_length</span><span class="p">,</span> <span class="n">inst_ids</span><span class="p">)</span> \
            <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">data_loader</span><span class="p">):</span>
        <span class="n">src_seq</span> <span class="o">=</span> <span class="n">src_seq</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
        <span class="n">tgt_seq</span> <span class="o">=</span> <span class="n">tgt_seq</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
        <span class="n">src_valid_length</span> <span class="o">=</span> <span class="n">src_valid_length</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
        <span class="n">tgt_valid_length</span> <span class="o">=</span> <span class="n">tgt_valid_length</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
        <span class="c1"># Calculating Loss</span>
        <span class="n">out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">src_seq</span><span class="p">,</span> <span class="n">tgt_seq</span><span class="p">[:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">src_valid_length</span><span class="p">,</span> <span class="n">tgt_valid_length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
        <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_function</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="n">tgt_seq</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:],</span> <span class="n">tgt_valid_length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">asscalar</span><span class="p">()</span>
        <span class="n">all_inst_ids</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">inst_ids</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
        <span class="n">avg_loss</span> <span class="o">+=</span> <span class="n">loss</span> <span class="o">*</span> <span class="p">(</span><span class="n">tgt_seq</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
        <span class="n">avg_loss_denom</span> <span class="o">+=</span> <span class="p">(</span><span class="n">tgt_seq</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
        <span class="c1"># Translate</span>
        <span class="n">samples</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">sample_valid_length</span> <span class="o">=</span>\
            <span class="n">translator</span><span class="o">.</span><span class="n">translate</span><span class="p">(</span><span class="n">src_seq</span><span class="o">=</span><span class="n">src_seq</span><span class="p">,</span> <span class="n">src_valid_length</span><span class="o">=</span><span class="n">src_valid_length</span><span class="p">)</span>
        <span class="n">max_score_sample</span> <span class="o">=</span> <span class="n">samples</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
        <span class="n">sample_valid_length</span> <span class="o">=</span> <span class="n">sample_valid_length</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">asnumpy</span><span class="p">()</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">max_score_sample</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
            <span class="n">translation_out</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
                <span class="p">[</span><span class="n">tgt_vocab</span><span class="o">.</span><span class="n">idx_to_token</span><span class="p">[</span><span class="n">ele</span><span class="p">]</span> <span class="k">for</span> <span class="n">ele</span> <span class="ow">in</span>
                 <span class="n">max_score_sample</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">:(</span><span class="n">sample_valid_length</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)]])</span>
    <span class="n">avg_loss</span> <span class="o">=</span> <span class="n">avg_loss</span> <span class="o">/</span> <span class="n">avg_loss_denom</span>
    <span class="n">real_translation_out</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">all_inst_ids</span><span class="p">))]</span>
    <span class="k">for</span> <span class="n">ind</span><span class="p">,</span> <span class="n">sentence</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">all_inst_ids</span><span class="p">,</span> <span class="n">translation_out</span><span class="p">):</span>
        <span class="n">real_translation_out</span><span class="p">[</span><span class="n">ind</span><span class="p">]</span> <span class="o">=</span> <span class="n">sentence</span>
    <span class="k">return</span> <span class="n">avg_loss</span><span class="p">,</span> <span class="n">real_translation_out</span>


<span class="k">def</span> <span class="nf">write_sentences</span><span class="p">(</span><span class="n">sentences</span><span class="p">,</span> <span class="n">file_path</span><span class="p">):</span>
    <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">file_path</span><span class="p">,</span> <span class="s1">&#39;w&#39;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s1">&#39;utf-8&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">of</span><span class="p">:</span>
        <span class="k">for</span> <span class="n">sent</span> <span class="ow">in</span> <span class="n">sentences</span><span class="p">:</span>
            <span class="n">of</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39; &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sent</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
</pre></div>
</div>
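<p>Because the bucketing samplers visit the samples in a different order from the dataset, <code class="docutils literal notranslate"><span class="pre">evaluate</span></code> relies on the instance ids to put the translations back in their original positions. The reordering step in isolation looks like this. The helper name <code class="docutils literal notranslate"><span class="pre">restore_order</span></code> and the toy data are made up for the example.</p>

```python
def restore_order(inst_ids, outputs):
    """Place each output at the dataset position given by its instance id."""
    restored = [None] * len(inst_ids)
    for ind, out in zip(inst_ids, outputs):
        restored[ind] = out
    return restored


# Outputs arrive in sampler order; the ids say where each one belongs.
inst_ids = [2, 0, 1]
outputs = [['c'], ['a'], ['b']]
print(restore_order(inst_ids, outputs))  # [['a'], ['b'], ['c']]
```

<p>Restoring the order matters when the translations are later compared line by line with the reference sentences.</p>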
</div>
<div class="section" id="training-epochs">
<h2>Training Epochs<a class="headerlink" href="#training-epochs" title="Permalink to this headline">¶</a></h2>
<p>Before entering the training stage, we need to create a trainer for
updating the parameters. In the following example, we create a trainer
that uses the Adam optimizer.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">trainer</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">Trainer</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">collect_params</span><span class="p">(),</span> <span class="s1">&#39;adam&#39;</span><span class="p">,</span> <span class="p">{</span><span class="s1">&#39;learning_rate&#39;</span><span class="p">:</span> <span class="n">lr</span><span class="p">})</span>
</pre></div>
</div>
<p>We can then write the training loop. During training, we evaluate on
the validation and testing datasets every epoch, and record the
parameters that give the highest BLEU score on the validation dataset.
Before performing the forward and backward passes, we first use the <code class="docutils literal notranslate"><span class="pre">as_in_context</span></code>
function to copy the mini-batch to the GPU. The statement
<code class="docutils literal notranslate"><span class="pre">with</span> <span class="pre">mx.autograd.record()</span></code> tells the Gluon backend to compute the
gradients for the part inside the block.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">best_valid_bleu</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="k">for</span> <span class="n">epoch_id</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
    <span class="n">log_avg_loss</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">log_avg_gnorm</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">log_wc</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">log_start_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
    <span class="k">for</span> <span class="n">batch_id</span><span class="p">,</span> <span class="p">(</span><span class="n">src_seq</span><span class="p">,</span> <span class="n">tgt_seq</span><span class="p">,</span> <span class="n">src_valid_length</span><span class="p">,</span> <span class="n">tgt_valid_length</span><span class="p">)</span>\
            <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_data_loader</span><span class="p">):</span>
        <span class="c1"># logging.info(src_seq.context) Context suddenly becomes GPU.</span>
        <span class="n">src_seq</span> <span class="o">=</span> <span class="n">src_seq</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
        <span class="n">tgt_seq</span> <span class="o">=</span> <span class="n">tgt_seq</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
        <span class="n">src_valid_length</span> <span class="o">=</span> <span class="n">src_valid_length</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
        <span class="n">tgt_valid_length</span> <span class="o">=</span> <span class="n">tgt_valid_length</span><span class="o">.</span><span class="n">as_in_context</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span>
        <span class="k">with</span> <span class="n">mx</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">record</span><span class="p">():</span>
            <span class="n">out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">src_seq</span><span class="p">,</span> <span class="n">tgt_seq</span><span class="p">[:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">src_valid_length</span><span class="p">,</span> <span class="n">tgt_valid_length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
            <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_function</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="n">tgt_seq</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:],</span> <span class="n">tgt_valid_length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
            <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span> <span class="o">*</span> <span class="p">(</span><span class="n">tgt_seq</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">tgt_valid_length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
            <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
        <span class="n">grads</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">ctx</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">collect_params</span><span class="p">()</span><span class="o">.</span><span class="n">values</span><span class="p">()]</span>
        <span class="n">gnorm</span> <span class="o">=</span> <span class="n">gluon</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_global_norm</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="n">clip</span><span class="p">)</span>
        <span class="n">trainer</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">src_wc</span> <span class="o">=</span> <span class="n">src_valid_length</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">asscalar</span><span class="p">()</span>
        <span class="n">tgt_wc</span> <span class="o">=</span> <span class="p">(</span><span class="n">tgt_valid_length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">asscalar</span><span class="p">()</span>
        <span class="n">step_loss</span> <span class="o">=</span> <span class="n">loss</span><span class="o">.</span><span class="n">asscalar</span><span class="p">()</span>
        <span class="n">log_avg_loss</span> <span class="o">+=</span> <span class="n">step_loss</span>
        <span class="n">log_avg_gnorm</span> <span class="o">+=</span> <span class="n">gnorm</span>
        <span class="n">log_wc</span> <span class="o">+=</span> <span class="n">src_wc</span> <span class="o">+</span> <span class="n">tgt_wc</span>
        <span class="k">if</span> <span class="p">(</span><span class="n">batch_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">log_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">wps</span> <span class="o">=</span> <span class="n">log_wc</span> <span class="o">/</span> <span class="p">(</span><span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">log_start_time</span><span class="p">)</span>
            <span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;[Epoch </span><span class="si">{}</span><span class="s1"> Batch </span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1">] loss=</span><span class="si">{:.4f}</span><span class="s1">, ppl=</span><span class="si">{:.4f}</span><span class="s1">, gnorm=</span><span class="si">{:.4f}</span><span class="s1">, &#39;</span>
                         <span class="s1">&#39;throughput=</span><span class="si">{:.2f}</span><span class="s1">K wps, wc=</span><span class="si">{:.2f}</span><span class="s1">K&#39;</span>
                         <span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">epoch_id</span><span class="p">,</span> <span class="n">batch_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_data_loader</span><span class="p">),</span>
                                 <span class="n">log_avg_loss</span> <span class="o">/</span> <span class="n">log_interval</span><span class="p">,</span>
                                 <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_avg_loss</span> <span class="o">/</span> <span class="n">log_interval</span><span class="p">),</span>
                                 <span class="n">log_avg_gnorm</span> <span class="o">/</span> <span class="n">log_interval</span><span class="p">,</span>
                                 <span class="n">wps</span> <span class="o">/</span> <span class="mi">1000</span><span class="p">,</span> <span class="n">log_wc</span> <span class="o">/</span> <span class="mi">1000</span><span class="p">))</span>
            <span class="n">log_start_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
            <span class="n">log_avg_loss</span> <span class="o">=</span> <span class="mi">0</span>
            <span class="n">log_avg_gnorm</span> <span class="o">=</span> <span class="mi">0</span>
            <span class="n">log_wc</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">valid_loss</span><span class="p">,</span> <span class="n">valid_translation_out</span> <span class="o">=</span> <span class="n">evaluate</span><span class="p">(</span><span class="n">val_data_loader</span><span class="p">)</span>
    <span class="n">valid_bleu_score</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">nmt</span><span class="o">.</span><span class="n">bleu</span><span class="o">.</span><span class="n">compute_bleu</span><span class="p">([</span><span class="n">val_tgt_sentences</span><span class="p">],</span> <span class="n">valid_translation_out</span><span class="p">)</span>
    <span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;[Epoch </span><span class="si">{}</span><span class="s1">] valid Loss=</span><span class="si">{:.4f}</span><span class="s1">, valid ppl=</span><span class="si">{:.4f}</span><span class="s1">, valid bleu=</span><span class="si">{:.2f}</span><span class="s1">&#39;</span>
                 <span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">epoch_id</span><span class="p">,</span> <span class="n">valid_loss</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">valid_loss</span><span class="p">),</span> <span class="n">valid_bleu_score</span> <span class="o">*</span> <span class="mi">100</span><span class="p">))</span>
    <span class="n">test_loss</span><span class="p">,</span> <span class="n">test_translation_out</span> <span class="o">=</span> <span class="n">evaluate</span><span class="p">(</span><span class="n">test_data_loader</span><span class="p">)</span>
    <span class="n">test_bleu_score</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">nmt</span><span class="o">.</span><span class="n">bleu</span><span class="o">.</span><span class="n">compute_bleu</span><span class="p">([</span><span class="n">test_tgt_sentences</span><span class="p">],</span> <span class="n">test_translation_out</span><span class="p">)</span>
    <span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;[Epoch </span><span class="si">{}</span><span class="s1">] test Loss=</span><span class="si">{:.4f}</span><span class="s1">, test ppl=</span><span class="si">{:.4f}</span><span class="s1">, test bleu=</span><span class="si">{:.2f}</span><span class="s1">&#39;</span>
                 <span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">epoch_id</span><span class="p">,</span> <span class="n">test_loss</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">test_loss</span><span class="p">),</span> <span class="n">test_bleu_score</span> <span class="o">*</span> <span class="mi">100</span><span class="p">))</span>
    <span class="n">write_sentences</span><span class="p">(</span><span class="n">valid_translation_out</span><span class="p">,</span>
                    <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">save_dir</span><span class="p">,</span> <span class="s1">&#39;epoch</span><span class="si">{:d}</span><span class="s1">_valid_out.txt&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">epoch_id</span><span class="p">))</span>
    <span class="n">write_sentences</span><span class="p">(</span><span class="n">test_translation_out</span><span class="p">,</span>
                    <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">save_dir</span><span class="p">,</span> <span class="s1">&#39;epoch</span><span class="si">{:d}</span><span class="s1">_test_out.txt&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">epoch_id</span><span class="p">))</span>
    <span class="k">if</span> <span class="n">valid_bleu_score</span> <span class="o">&gt;</span> <span class="n">best_valid_bleu</span><span class="p">:</span>
        <span class="n">best_valid_bleu</span> <span class="o">=</span> <span class="n">valid_bleu_score</span>
        <span class="n">save_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">save_dir</span><span class="p">,</span> <span class="s1">&#39;valid_best.params&#39;</span><span class="p">)</span>
        <span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;Save best parameters to </span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">save_path</span><span class="p">))</span>
        <span class="n">model</span><span class="o">.</span><span class="n">save_parameters</span><span class="p">(</span><span class="n">save_path</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">epoch_id</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">&gt;=</span> <span class="p">(</span><span class="n">epochs</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span> <span class="o">//</span> <span class="mi">3</span><span class="p">:</span>
        <span class="n">new_lr</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">*</span> <span class="n">lr_update_factor</span>
        <span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;Learning rate change to </span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">new_lr</span><span class="p">))</span>
        <span class="n">trainer</span><span class="o">.</span><span class="n">set_learning_rate</span><span class="p">(</span><span class="n">new_lr</span><span class="p">)</span>
</pre></div>
</div>
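<p>The two loss lines in the loop above turn an average over padded positions into an average over real target tokens. The NumPy sketch below illustrates that arithmetic. It assumes the loss function zeroes out padded positions and averages over the padded time axis. The function name <code>per_token_loss</code> and the numbers are hypothetical. Here <code>T</code> plays the role of <code>tgt_seq.shape[1] - 1</code> and <code>valid_length</code> the role of <code>tgt_valid_length - 1</code>.</p>

```python
import numpy as np

def per_token_loss(token_losses, valid_length):
    """Mimic the two loss lines of the training loop with NumPy arrays.

    token_losses: (batch, T) per-position losses (hypothetical values)
    valid_length: (batch,) number of real target tokens per sentence
    """
    batch, T = token_losses.shape
    # Zero out padded positions, as a masked loss is assumed to do.
    mask = valid_length[:, None] > np.arange(T)[None, :]
    masked = token_losses * mask
    # Average over the padded time axis, then over the batch.
    loss = masked.mean(axis=1).mean()
    # Rescale: multiply by T and divide by the mean valid length.
    return loss * T / valid_length.mean()

# Two sentences padded to T = 4; the 9.0 entries sit on padding.
token_losses = np.array([[2.0, 1.0, 3.0, 9.0],
                         [4.0, 2.0, 9.0, 9.0]])
valid_length = np.array([3, 2])

rescaled = per_token_loss(token_losses, valid_length)
# Direct average over the five real tokens.
direct = (2.0 + 1.0 + 3.0 + 4.0 + 2.0) / 5
print(rescaled, direct)
```

<p>Under these assumptions both values agree, so the logged loss is a per-token quantity and its exponential can be read as a perplexity.</p>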
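<p>The loop clips gradients by their global norm before each <code>trainer.step</code>. As a rough illustration of the idea (a NumPy sketch, not the Gluon implementation), all gradient arrays are treated as one long vector, and every array is scaled by the same factor when that vector's L2 norm exceeds the threshold. The helper name <code>clip_global_norm_sketch</code> and the example values are hypothetical.</p>

```python
import numpy as np

def clip_global_norm_sketch(grads, max_norm):
    """Rescale arrays in place so their joint L2 norm is at most max_norm.

    Returns the norm measured before clipping.
    """
    total_norm = float(np.sqrt(sum((g ** 2).sum() for g in grads)))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        for g in grads:
            g *= scale  # same factor for every array keeps the direction
    return total_norm

# Two "parameter gradients" whose joint norm is sqrt(3^2 + 4^2) = 5.
grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]
gnorm = clip_global_norm_sketch(grads, max_norm=1.0)
clipped_norm = float(np.sqrt(sum((g ** 2).sum() for g in grads)))
print(gnorm, clipped_norm)

# Gradients already below the threshold are left untouched.
small = [np.array([1.0])]
small_norm = clip_global_norm_sketch(small, max_norm=10.0)
```

<p>Scaling all arrays by one common factor preserves the direction of the update, which is why the pre-clipping norm is still useful to log as <code>gnorm</code>.</p>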
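<p>The last branch of the loop multiplies the learning rate by <code>lr_update_factor</code> at the end of every epoch once roughly two thirds of the epochs have run. The sketch below replays that rule in plain Python. It assumes the epochs are iterated as <code>for epoch_id in range(epochs)</code>, and the values passed in are purely illustrative, not the tutorial's settings.</p>

```python
def lr_schedule(epochs, base_lr, lr_update_factor):
    """Return the learning rate used during each epoch under the decay rule."""
    lr = base_lr
    used = []
    for epoch_id in range(epochs):
        used.append(lr)  # rate in effect while training this epoch
        # Same condition as in the training loop above.
        if epoch_id + 1 >= (epochs * 2) // 3:
            lr = lr * lr_update_factor
    return used

# Hypothetical settings chosen only to make the pattern visible.
rates = lr_schedule(epochs=6, base_lr=0.001, lr_update_factor=0.5)
for epoch_id, rate in enumerate(rates):
    print(epoch_id, rate)
```

<p>With six epochs the threshold is <code>(6 * 2) // 3 = 4</code>, so the first four epochs run at the base rate and each later epoch runs at the previous rate times the factor.</p>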
</div>
<div class="section" id="summary">
<h2>Summary<a class="headerlink" href="#summary" title="Permalink to this headline">¶</a></h2>
<p>In this notebook, we showed how to train a GNMT model on the IWSLT 2015
English-Vietnamese dataset using the Gluon NLP toolkit. The complete training script
is available as
<a class="reference external" href="https://github.com/dmlc/gluon-nlp/blob/master/scripts/nmt/train_gnmt.py">train_gnmt.py</a>.
The command to reproduce the results is listed on the <a class="reference external" href="http://gluon-nlp.mxnet.io/scripts/index.html#machine-translation">NMT scripts
page</a>.</p>
</div>
</div>


        </div>
        <div class="side-doc-outline">
            <div class="side-doc-outline--content"> 
<div class="localtoc">
    <p class="caption">
      <span class="caption-text">Table Of Contents</span>
    </p>
    <ul>
<li><a class="reference internal" href="#">Google Neural Machine Translation</a><ul>
<li><a class="reference internal" href="#load-mxnet-and-gluon">Load MXNET and Gluon</a></li>
<li><a class="reference internal" href="#hyper-parameters">Hyper-parameters</a></li>
<li><a class="reference internal" href="#load-and-preprocess-dataset">Load and Preprocess Dataset</a></li>
<li><a class="reference internal" href="#create-sampler-and-dataloader">Create Sampler and DataLoader</a></li>
<li><a class="reference internal" href="#build-gnmt-model">Build GNMT Model</a></li>
<li><a class="reference internal" href="#training-epochs">Training Epochs</a></li>
<li><a class="reference internal" href="#summary">Summary</a></li>
</ul>
</li>
</ul>

</div>
            </div>
        </div>                    

      <div class="clearer"></div>
    </div><div class="pagenation">
     <a id="button-prev" href="index.html" class="mdl-button mdl-js-button mdl-js-ripple-effect mdl-button--colored" role="button" accesskey="P">
         <i class="pagenation-arrow-L fas fa-arrow-left fa-lg"></i>
         <div class="pagenation-text">
            <span class="pagenation-direction">Previous</span>
            <div>Text Tutorials</div>
         </div>
     </a>
     <a id="button-next" href="transformer.html" class="mdl-button mdl-js-button mdl-js-ripple-effect mdl-button--colored" role="button" accesskey="N">
         <i class="pagenation-arrow-R fas fa-arrow-right fa-lg"></i>
        <div class="pagenation-text">
            <span class="pagenation-direction">Next</span>
            <div>Machine Translation with Transformer</div>
        </div>
     </a>
  </div>
            <footer class="site-footer h-card">
    <div class="wrapper">
        <div class="row">
            <div class="col-4">
                <h4 class="footer-category-title">Resources</h4>
                <ul class="contact-list">
                    <li><a class="u-email" href="mailto:dev@mxnet.apache.org">Dev list</a></li>
                    <li><a class="u-email" href="mailto:user@mxnet.apache.org">User mailing list</a></li>
                    <li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li>
                    <li><a href="https://issues.apache.org/jira/projects/MXNET/issues">Jira Tracker</a></li>
                    <li><a href="https://github.com/apache/incubator-mxnet/labels/Roadmap">GitHub Roadmap</a></li>
                    <li><a href="https://discuss.mxnet.io">MXNet Discuss forum</a></li>
                    <li><a href="/community/contribute">Contribute To MXNet</a></li>

                </ul>
            </div>

            <div class="col-4"><ul class="social-media-list"><li><a href="https://github.com/apache/incubator-mxnet"><svg class="svg-icon"><use xlink:href="../../../../_static/minima-social-icons.svg#github"></use></svg> <span class="username">apache/incubator-mxnet</span></a></li><li><a href="https://www.twitter.com/apachemxnet"><svg class="svg-icon"><use xlink:href="../../../../_static/minima-social-icons.svg#twitter"></use></svg> <span class="username">apachemxnet</span></a></li><li><a href="https://youtube.com/apachemxnet"><svg class="svg-icon"><use xlink:href="../../../../_static/minima-social-icons.svg#youtube"></use></svg> <span class="username">apachemxnet</span></a></li></ul>
</div>

            <div class="col-4 footer-text">
                <p>A flexible and efficient library for deep learning.</p>
            </div>
        </div>
    </div>
</footer>

<footer class="site-footer2">
    <div class="wrapper">
        <div class="row">
            <div class="col-3">
                <img src="../../../../_static/apache_incubator_logo.png" class="footer-logo col-2">
            </div>
            <div class="footer-bottom-warning col-9">
                <p>Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <span style="font-weight:bold">sponsored by the <i>Apache Incubator</i></span>. Incubation is required
                    of all newly accepted projects until a further review indicates that the infrastructure,
                    communications, and decision making process have stabilized in a manner consistent with other
                    successful ASF projects. While incubation status is not necessarily a reflection of the completeness
                    or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF.
                </p><p>Copyright © 2017-2018, The Apache Software Foundation. Apache MXNet, MXNet, Apache, the Apache
                    feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the
                    Apache Software Foundation.</p>
            </div>
        </div>
    </div>
</footer>
        
  </body>
</html>