# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import random as pyrnd
import argparse
import numpy as np
import mxnet as mx
from matplotlib import pyplot as plt
import binary_rbm
### Command line arguments
parser = argparse.ArgumentParser(description='Train a restricted Boltzmann machine on MNIST')
parser.add_argument('--num-hidden', type=int, default=500, help='number of hidden units')
parser.add_argument('--k', type=int, default=30, help='number of Gibbs sampling steps used in the PCD algorithm')
parser.add_argument('--batch-size', type=int, default=80, help='batch size')
parser.add_argument('--num-epoch', type=int, default=130, help='number of epochs')
parser.add_argument('--learning-rate', type=float, default=0.1, help='learning rate for stochastic gradient descent') # The optimizer rescales this with `1 / batch_size`
parser.add_argument('--momentum', type=float, default=0.3, help='momentum for the stochastic gradient descent')
parser.add_argument('--ais-batch-size', type=int, default=100, help='batch size for AIS to estimate the log-likelihood')
parser.add_argument('--ais-num-batch', type=int, default=10, help='number of batches for AIS to estimate the log-likelihood')
parser.add_argument('--ais-intermediate-steps', type=int, default=10, help='number of intermediate distributions for AIS to estimate the log-likelihood')
parser.add_argument('--ais-burn-in-steps', type=int, default=10, help='number of burn-in steps for each intermediate distribution of AIS to estimate the log-likelihood')
parser.add_argument('--cuda', action='store_true', dest='cuda', help='train on GPU with CUDA')
parser.add_argument('--no-cuda', action='store_false', dest='cuda', help='train on CPU')
parser.add_argument('--device-id', type=int, default=0, help='GPU device id')
parser.set_defaults(cuda=True)
args = parser.parse_args()
print(args)
### Global environment
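# Seed MXNet's RNG with a fresh 32-bit value drawn from Python's RNG so that each run uses a different seed.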
mx.random.seed(pyrnd.getrandbits(32))
ctx = mx.gpu(args.device_id) if args.cuda else mx.cpu()
### Prepare data
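# mx.test_utils.get_mnist() downloads MNIST if necessary and returns float32 arrays of shape (num_samples, 1, 28, 28).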
mnist = mx.test_utils.get_mnist() # Each pixel has a value in [0, 1].
mnist_train_data = mnist['train_data']
mnist_test_data = mnist['test_data']
img_height = mnist_train_data.shape[2]
img_width = mnist_train_data.shape[3]
num_visible = img_width * img_height
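# Each 28x28 image corresponds to num_visible = 784 visible units of the RBM.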
# The iterators generate arrays with shape (batch_size, num_channel = 1, height = 28, width = 28)
train_iter = mx.io.NDArrayIter(
data={'data': mnist_train_data},
batch_size=args.batch_size,
shuffle=True)
test_iter = mx.io.NDArrayIter(
data={'data': mnist_test_data},
batch_size=args.batch_size,
shuffle=True)
### Define symbols
data = mx.sym.Variable('data') # (batch_size, num_channel = 1, height, width)
flattened_data = mx.sym.flatten(data=data) # (batch_size, num_channel * height * width)
visible_layer_bias = mx.sym.Variable('visible_layer_bias', init=mx.init.Normal(sigma=.01))
hidden_layer_bias = mx.sym.Variable('hidden_layer_bias', init=mx.init.Normal(sigma=.01))
interaction_weight = mx.sym.Variable('interaction_weight', init=mx.init.Normal(sigma=.01))
aux_hidden_layer_sample = mx.sym.Variable('aux_hidden_layer_sample', init=mx.init.Normal(sigma=.01))
aux_hidden_layer_prob_1 = mx.sym.Variable('aux_hidden_layer_prob_1', init=mx.init.Constant(0))
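# The two auxiliary variables presumably carry the state of the persistent Gibbs chain between
# parameter updates (the last hidden-layer sample and its activation probabilities); their exact
# semantics are defined by the custom BinaryRBM operator in binary_rbm.py (not shown here).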
### Train
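# The custom BinaryRBM operator (defined in binary_rbm.py) runs k Gibbs sampling steps per
# forward/backward pass to produce a persistent contrastive divergence (PCD-k) gradient estimate.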
rbm = mx.sym.Custom(
flattened_data,
visible_layer_bias,
hidden_layer_bias,
interaction_weight,
aux_hidden_layer_sample,
aux_hidden_layer_prob_1,
num_hidden=args.num_hidden,
k=args.k,
for_training=True,
op_type='BinaryRBM',
name='rbm')
model = mx.mod.Module(symbol=rbm, context=ctx, data_names=['data'], label_names=None)
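# Bind to the training data shapes, initialize the parameters with the initializers declared on
# the symbols above, and use plain SGD with momentum.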
model.bind(data_shapes=train_iter.provide_data)
model.init_params()
model.init_optimizer(optimizer='sgd', optimizer_params={'learning_rate': args.learning_rate, 'momentum': args.momentum})
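# The RBM is trained unsupervised, so there are no labels or evaluation metrics here; a plain
# forward/backward/update cycle per batch is all that is needed.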
for epoch in range(args.num_epoch):
# Update parameters
train_iter.reset()
for batch in train_iter:
model.forward(batch)
model.backward()
model.update()
mx.nd.waitall()
    # Monitor the performance of the model
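    # AIS approximates the intractable partition function through a sequence of intermediate
    # distributions, each with a short burn-in, which makes the train and test log-likelihoods
    # below computable.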
params = model.get_params()[0]
param_visible_layer_bias = params['visible_layer_bias'].as_in_context(ctx)
param_hidden_layer_bias = params['hidden_layer_bias'].as_in_context(ctx)
param_interaction_weight = params['interaction_weight'].as_in_context(ctx)
test_iter.reset()
test_log_likelihood, _ = binary_rbm.estimate_log_likelihood(
param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight,
args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, test_iter, ctx)
train_iter.reset()
train_log_likelihood, _ = binary_rbm.estimate_log_likelihood(
param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight,
args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, train_iter, ctx)
print("Epoch %d completed with test log-likelihood %f and train log-likelihood %f" % (epoch, test_log_likelihood, train_log_likelihood))
### Show some samples.
# Each sample is obtained by 3000 steps of Gibbs sampling starting from a real sample.
# Starting from real data is only for convenience of implementation; after this many steps
# the resulting samples are essentially uncorrelated with the initial states.
# You could equally start from random states and run the Gibbs chain for a sufficiently long time.
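# In the showcase image built below, the left half shows the real MNIST inputs and the right half
# the corresponding Gibbs samples (each pixel being the probability that the visible unit is 1),
# separated by a yellow vertical line.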
print("Preparing showcase")
showcase_gibbs_sampling_steps = 3000
showcase_num_samples_w = 15
showcase_num_samples_h = 15
showcase_num_samples = showcase_num_samples_w * showcase_num_samples_h
showcase_img_shape = (showcase_num_samples_h * img_height, 2 * showcase_num_samples_w * img_width)
showcase_img_column_shape = (showcase_num_samples_h * img_height, img_width)
params = model.get_params()[0] # We don't need aux states here
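# Rebuild the same custom operator with for_training=False and a much larger k so that a single
# forward pass runs the full 3000-step Gibbs chain.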
showcase_rbm = mx.sym.Custom(
flattened_data,
visible_layer_bias,
hidden_layer_bias,
interaction_weight,
num_hidden=args.num_hidden,
k=showcase_gibbs_sampling_steps,
for_training=False,
op_type='BinaryRBM',
name='showcase_rbm')
showcase_iter = mx.io.NDArrayIter(
data={'data': mnist['train_data']},
batch_size=showcase_num_samples_h,
shuffle=True)
showcase_model = mx.mod.Module(symbol=showcase_rbm, context=ctx, data_names=['data'], label_names=None)
showcase_model.bind(data_shapes=showcase_iter.provide_data, for_training=False)
showcase_model.set_params(params, aux_params=None)
showcase_img = np.zeros(showcase_img_shape)
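# Module.iter_predict yields (outputs, batch index, input batch) for each predicted batch.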
for sample_batch, i, data_batch in showcase_model.iter_predict(eval_data=showcase_iter, num_batch=showcase_num_samples_w):
# Each pixel is the probability that the unit is 1.
showcase_img[:, i * img_width : (i + 1) * img_width] = data_batch.data[0].reshape(showcase_img_column_shape).asnumpy()
showcase_img[:, (showcase_num_samples_w + i) * img_width : (showcase_num_samples_w + i + 1) * img_width
] = sample_batch[0].reshape(showcase_img_column_shape).asnumpy()
plt.imshow(showcase_img, cmap='gray')
plt.axis('off')
plt.axvline(showcase_num_samples_w * img_width, color='y')
plt.show()
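# To run this script headless, one could save the figure instead of showing it, e.g. with
# plt.savefig('rbm_showcase.png', bbox_inches='tight') (the file name is just an example).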
print("Done")