In this example, we will demonstrate how you can implement a Variational Auto-Encoder (VAE) with Gluon.probability and MXNet's latest NumPy API.
# Imports and global setup for the VAE tutorial.
# NOTE(review): the original also had ``import numpy as np``, which was
# immediately shadowed by ``from mxnet import np`` below and therefore dead;
# it has been removed — everything here uses MXNet's NumPy-compatible module.
import mxnet as mx
from mxnet import autograd, gluon, np, npx
from mxnet.gluon import nn
import mxnet.gluon.probability as mgp
import matplotlib.pyplot as plt

# Switch numpy-compatible semantics on (mx.np arrays behave like numpy).
npx.set_np()

# Context (device) for model parameters and data. Assumes a GPU is
# available — switch to mx.cpu() when running on a CPU-only machine.
model_ctx = mx.gpu(0)
We will use MNIST here for simplicity.
def load_data(batch_size):
    """Return ``(train_loader, test_loader)`` for MNIST.

    Images are converted to tensors via ``ToTensor``; the training loader
    shuffles, the test loader does not. Both use 4 worker processes.
    """
    to_tensor = gluon.data.vision.transforms.ToTensor()
    workers = 4
    train_ds = gluon.data.vision.MNIST(train=True).transform_first(to_tensor)
    test_ds = gluon.data.vision.MNIST(train=False).transform_first(to_tensor)
    train_loader = gluon.data.DataLoader(
        train_ds, batch_size, shuffle=True, num_workers=workers)
    test_loader = gluon.data.DataLoader(
        test_ds, batch_size, shuffle=False, num_workers=workers)
    return train_loader, test_loader
class VAE(gluon.HybridBlock):
    r"""Variational Auto-Encoder built on Gluon.probability.

    The encoder maps an observation ``x`` to the parameters of a diagonal
    Gaussian posterior q(z|x); the decoder maps a latent sample ``z`` back
    to Bernoulli means over the observed pixels.

    Parameters
    ----------
    n_hidden : int
        Number of hidden units in each dense layer.
    n_latent : int
        Dimension of the latent space.
    n_layers : int
        Number of hidden layers in both the encoder and the decoder.
    n_output : int
        Dimension of the observed data (784 for flattened 28x28 MNIST).
    act_type : str
        Activation function used by the hidden layers.
    """

    def __init__(self, n_hidden=256, n_latent=2, n_layers=1, n_output=784,
                 act_type='relu', **kwargs):
        # Call the parent constructor FIRST so Gluon's attribute
        # registration machinery is fully set up before any attributes are
        # assigned (the original set plain attributes before super().__init__).
        super(VAE, self).__init__(**kwargs)
        self.soft_zero = 1e-10  # numerical floor so log() never sees 0
        self.n_latent = n_latent
        self.output = None  # kept for interface compatibility (unused here)
        self.mu = None      # kept for interface compatibility (unused here)
        self.encoder = nn.HybridSequential()
        for _ in range(n_layers):
            self.encoder.add(nn.Dense(n_hidden, activation=act_type))
        # Final encoder layer emits [loc, log_scale] concatenated.
        self.encoder.add(nn.Dense(n_latent * 2, activation=None))
        self.decoder = nn.HybridSequential()
        for _ in range(n_layers):
            self.decoder.add(nn.Dense(n_hidden, activation=act_type))
        # Sigmoid output so every pixel is a Bernoulli mean in (0, 1).
        self.decoder.add(nn.Dense(n_output, activation='sigmoid'))

    def encode(self, x):
        r"""Given a batch of ``x``, return the posterior q(z|x).

        The encoder output is laid out as
        ``[loc_1, ..., loc_n, log(scale_1), ..., log(scale_n)]``;
        it is split in half along axis 1 to recover loc and log_scale.
        """
        h = self.encoder(x)
        loc_scale = np.split(h, 2, 1)
        loc = loc_scale[0]
        log_scale = loc_scale[1]
        # exp() converts log_scale back to a strictly positive scale.
        scale = np.exp(log_scale)
        return mgp.Normal(loc, scale)

    def decode(self, z):
        r"""Given a batch of latent samples ``z``, return the decoder output."""
        return self.decoder(z)

    def forward(self, x):
        r"""Return the negative Evidence Lower BOund for batch ``x``.

        This is the objective to MINIMIZE:
        -E_q[log p(x|z)] + KL(q(z|x) || p(z)), per sample.
        """
        # Standard-normal prior p(z).
        pz = mgp.Normal(0, 1)
        # Posterior q(z|x) from the encoder.
        qz_x = self.encode(x)
        # qz_x.sample() is automatically reparameterized, so gradients flow
        # through the sampling operation.
        z = qz_x.sample()
        # Reconstruction (Bernoulli means).
        y = self.decode(z)
        # Gluon.probability computes the analytical KL divergence between
        # the two distribution objects; sum over latent dimensions.
        KL = mgp.kl_divergence(qz_x, pz).sum(1)
        # p(x|z) ~ Bernoulli -> binary cross-entropy reconstruction term;
        # soft_zero keeps log() finite at the saturation points.
        logloss = np.sum(x * np.log(y + self.soft_zero) +
                         (1 - x) * np.log(1 - y + self.soft_zero), axis=1)
        loss = -logloss + KL
        return loss
def train(net, n_epoch, print_period, train_iter, test_iter):
    """Fit ``net`` with Adam, reporting train/validation loss periodically.

    Parameters: the model, number of epochs, print interval (in epochs),
    and the training/validation data loaders.
    """
    net.initialize(mx.init.Xavier(), ctx=model_ctx)
    net.hybridize()
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': .001})
    training_loss = []
    validation_loss = []
    for epoch in range(n_epoch):
        # --- training pass ---
        epoch_loss = 0
        n_batch_train = 0
        for batch in train_iter:
            n_batch_train += 1
            data = batch[0].as_in_context(model_ctx).reshape(-1, 28 * 28)
            with autograd.record():
                loss = net(data)
            loss.backward()
            trainer.step(data.shape[0])
            epoch_loss += np.mean(loss)
        # --- validation pass (no gradient recording) ---
        epoch_val_loss = 0
        n_batch_val = 0
        for batch in test_iter:
            n_batch_val += 1
            data = batch[0].as_in_context(model_ctx).reshape(-1, 28 * 28)
            epoch_val_loss += np.mean(net(data))
        epoch_loss /= n_batch_train
        epoch_val_loss /= n_batch_val
        training_loss.append(epoch_loss)
        validation_loss.append(epoch_val_loss)
        # max(print_period, 1) guards against a zero print interval.
        if epoch % max(print_period, 1) == 0:
            print('Epoch{}, Training loss {:.2f}, Validation loss {:.2f}'
                  .format(epoch, float(epoch_loss), float(epoch_val_loss)))
# Hyper-parameters for the model and the training run.
n_hidden = 128
n_latent = 40
n_layers = 3
n_output = 784  # flattened 28x28 MNIST images
batch_size = 128
# File name for (optionally) saving the trained parameters; encodes the
# architecture so different configurations do not collide.
model_prefix = 'vae_gluon_{}d{}l{}h.params'.format(
    n_latent, n_layers, n_hidden)
net = VAE(n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
          n_output=n_output)
# NOTE(review): train() also calls hybridize(); this earlier call is
# harmless but redundant.
net.hybridize()
n_epoch = 50
print_period = n_epoch // 10
train_set, test_set = load_data(batch_size)
train(net, n_epoch, print_period, train_set, test_set)
To verify the effectiveness of our model, we first take a look at how well it can reconstruct the data.
# Grab a batch from the test set qz_x = None for batch in test_set: data = batch[0].as_in_context(model_ctx).reshape(-1, 28 * 28) qz_x = net.encode(data) break
# Plot a few originals side by side with reconstructions decoded from
# samples of the posterior q(z|x) computed above.
num_samples = 4
fig, axes = plt.subplots(nrows=num_samples, ncols=2, figsize=(4, 6),
                         subplot_kw={'xticks': [], 'yticks': []})
axes[0, 0].set_title('Original image')
axes[0, 1].set_title('reconstruction')
for i in range(num_samples):
    axes[i, 0].imshow(data[i].squeeze().reshape(28, 28).asnumpy(),
                      cmap='gray')
    # qz_x.sample() draws a fresh latent sample (for the whole batch) on
    # every iteration, so each row's reconstruction uses a different draw.
    axes[i, 1].imshow(net.decode(qz_x.sample())[i].reshape(28, 28).asnumpy(),
                      cmap='gray')
One of the most important differences between a Variational Auto-Encoder and a plain Auto-Encoder is the VAE's capability of generating new samples.
To achieve that, one simply needs to feed a random sample $z \sim \mathcal{N}(0,1)$ drawn from the prior $p(z)$ to the decoder network.
def plot_samples(samples, h=5, w=10):
    """Show the first ``h * w`` images from ``samples`` in an h-by-w grid.

    Each image is rendered in grayscale with axis ticks hidden.
    """
    fig, axes = plt.subplots(nrows=h, ncols=w,
                             figsize=(int(1.4 * w), int(1.4 * h)),
                             subplot_kw={'xticks': [], 'yticks': []})
    flat_axes = axes.flatten()
    for idx, axis in enumerate(flat_axes):
        axis.imshow(samples[idx], cmap='gray')
# Generate brand-new images: draw latent codes from the standard-normal
# prior and push them through the decoder only.
n_samples = 20
noise = np.random.randn(n_samples, n_latent).as_in_context(model_ctx)
dec_output = net.decode(noise).reshape(-1, 28, 28).asnumpy()
plot_samples(dec_output, 4, 5)