from __future__ import print_function
import mxnet as mx
import mxnet.ndarray as nd
import numpy
import logging
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
import argparse
import time
from algos import *
from data_loader import *
from utils import *
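
# Custom output ops implemented with the legacy mx.operator.NumpyOp API.
# CrossEntropySoftmax outputs class probabilities and back-propagates (softmax - label);
# LogSoftmax outputs log-probabilities and back-propagates (exp(output) - label). LogSoftmax is
# used below as the output op of the distilled student network on MNIST.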
class CrossEntropySoftmax(mx.operator.NumpyOp):
def __init__(self):
super(CrossEntropySoftmax, self).__init__(False)
def list_arguments(self):
return ['data', 'label']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
data_shape = in_shape[0]
label_shape = in_shape[0]
output_shape = in_shape[0]
return [data_shape, label_shape], [output_shape]
def forward(self, in_data, out_data):
x = in_data[0]
y = out_data[0]
y[:] = numpy.exp(x - x.max(axis=1).reshape((x.shape[0], 1))).astype('float32')
y /= y.sum(axis=1).reshape((x.shape[0], 1))
def backward(self, out_grad, in_data, out_data, in_grad):
l = in_data[1]
y = out_data[0]
dx = in_grad[0]
dx[:] = (y - l)
class LogSoftmax(mx.operator.NumpyOp):
def __init__(self):
super(LogSoftmax, self).__init__(False)
def list_arguments(self):
return ['data', 'label']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
data_shape = in_shape[0]
label_shape = in_shape[0]
output_shape = in_shape[0]
return [data_shape, label_shape], [output_shape]
def forward(self, in_data, out_data):
x = in_data[0]
y = out_data[0]
y[:] = (x - x.max(axis=1, keepdims=True)).astype('float32')
y -= numpy.log(numpy.exp(y).sum(axis=1, keepdims=True)).astype('float32')
# y[:] = numpy.exp(x - x.max(axis=1).reshape((x.shape[0], 1)))
# y /= y.sum(axis=1).reshape((x.shape[0], 1))
def backward(self, out_grad, in_data, out_data, in_grad):
l = in_data[1]
y = out_data[0]
dx = in_grad[0]
dx[:] = (numpy.exp(y) - l).astype('float32')
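
# Gradient functions used to distill the SGLD teacher's predictions into the student network.
# Classification: the gradient passed back to the student is (student prediction - teacher prediction).
# Regression: the student predicts a mean and a log-variance; the gradients below are those of the
# expected Gaussian negative log-likelihood, averaged over the teacher's predictive distribution
# N(teacher_pred, 1 / teacher_noise_precision).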
def classification_student_grad(student_outputs, teacher_pred):
return [student_outputs[0] - teacher_pred]
def regression_student_grad(student_outputs, teacher_pred, teacher_noise_precision):
student_mean = student_outputs[0]
student_var = student_outputs[1]
grad_mean = nd.exp(-student_var) * (student_mean - teacher_pred)
grad_var = (1 - nd.exp(-student_var) * (nd.square(student_mean - teacher_pred)
+ 1.0 / teacher_noise_precision)) / 2
return [grad_mean, grad_var]
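
# MLP for MNIST: two ReLU hidden layers of num_hidden units and a 10-way output layer,
# finished with SoftmaxOutput unless a custom output op (e.g. LogSoftmax) is supplied.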
def get_mnist_sym(output_op=None, num_hidden=400):
net = mx.symbol.Variable('data')
net = mx.symbol.FullyConnected(data=net, name='mnist_fc1', num_hidden=num_hidden)
net = mx.symbol.Activation(data=net, name='mnist_relu1', act_type="relu")
net = mx.symbol.FullyConnected(data=net, name='mnist_fc2', num_hidden=num_hidden)
net = mx.symbol.Activation(data=net, name='mnist_relu2', act_type="relu")
net = mx.symbol.FullyConnected(data=net, name='mnist_fc3', num_hidden=10)
if output_op is None:
net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
else:
net = output_op(data=net, name='softmax')
return net
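
# Stochastic gradient of the negative log-posterior for the two-parameter Gaussian-mixture
# model used in the SGLD paper's synthetic experiment:
#   theta1 ~ N(0, sigma1^2),  theta2 ~ N(0, sigma2^2),
#   x ~ 0.5 * N(theta1, sigmax^2) + 0.5 * N(theta1 + theta2, sigmax^2).
# rescale_grad (dataset size / minibatch size) scales the minibatch likelihood term up to the
# full dataset, as required for the SGLD update.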
def synthetic_grad(X, theta, sigma1, sigma2, sigmax, rescale_grad=1.0, grad=None):
if grad is None:
grad = nd.empty(theta.shape, theta.context)
theta1 = theta.asnumpy()[0]
theta2 = theta.asnumpy()[1]
v1 = sigma1 ** 2
v2 = sigma2 ** 2
vx = sigmax ** 2
denominator = numpy.exp(-(X - theta1) ** 2 / (2 * vx)) + numpy.exp(
-(X - theta1 - theta2) ** 2 / (2 * vx))
grad_npy = numpy.zeros(theta.shape)
grad_npy[0] = -rescale_grad * ((numpy.exp(-(X - theta1) ** 2 / (2 * vx)) * (X - theta1) / vx
+ numpy.exp(-(X - theta1 - theta2) ** 2 / (2 * vx)) * (
X - theta1 - theta2) / vx) / denominator).sum() \
+ theta1 / v1
grad_npy[1] = -rescale_grad * ((numpy.exp(-(X - theta1 - theta2) ** 2 / (2 * vx)) * (
X - theta1 - theta2) / vx) / denominator).sum() \
+ theta2 / v2
grad[:] = grad_npy
return grad
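
# Symbols for the 1-D toy regression problem. The teacher is a one-hidden-layer regressor whose
# LinearRegressionOutput grad_scale encodes the assumed observation noise precision; the student
# has two output heads (predictive mean and log-variance) grouped into a single symbol.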
def get_toy_sym(teacher=True, teacher_noise_precision=None):
if teacher:
net = mx.symbol.Variable('data')
net = mx.symbol.FullyConnected(data=net, name='teacher_fc1', num_hidden=100)
net = mx.symbol.Activation(data=net, name='teacher_relu1', act_type="relu")
net = mx.symbol.FullyConnected(data=net, name='teacher_fc2', num_hidden=1)
net = mx.symbol.LinearRegressionOutput(data=net, name='teacher_output',
grad_scale=teacher_noise_precision)
else:
net = mx.symbol.Variable('data')
net = mx.symbol.FullyConnected(data=net, name='student_fc1', num_hidden=100)
net = mx.symbol.Activation(data=net, name='student_relu1', act_type="relu")
student_mean = mx.symbol.FullyConnected(data=net, name='student_mean', num_hidden=1)
student_var = mx.symbol.FullyConnected(data=net, name='student_var', num_hidden=1)
net = mx.symbol.Group([student_mean, student_var])
return net
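
# Default device for all experiments below; switch to mx.cpu() if no GPU is available.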
def dev():
return mx.gpu()
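
# Baseline: train the MNIST MLP with plain SGD (a single point estimate, no posterior sampling).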
def run_mnist_SGD(training_num=50000):
X, Y, X_test, Y_test = load_mnist(training_num)
minibatch_size = 100
net = get_mnist_sym()
data_shape = (minibatch_size,) + X.shape[1::]
data_inputs = {'data': nd.zeros(data_shape, ctx=dev()),
'softmax_label': nd.zeros((minibatch_size,), ctx=dev())}
initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)
exe, exe_params, _ = SGD(sym=net, dev=dev(), data_inputs=data_inputs, X=X, Y=Y,
X_test=X_test, Y_test=Y_test,
total_iter_num=1000000,
initializer=initializer,
lr=5E-6, prior_precision=1.0, minibatch_size=100)
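
# SGLD on MNIST: after burn-in, weight samples are kept every thin_interval iterations.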
def run_mnist_SGLD(training_num=50000):
X, Y, X_test, Y_test = load_mnist(training_num)
minibatch_size = 100
net = get_mnist_sym()
data_shape = (minibatch_size,) + X.shape[1::]
data_inputs = {'data': nd.zeros(data_shape, ctx=dev()),
'softmax_label': nd.zeros((minibatch_size,), ctx=dev())}
initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)
exe, sample_pool = SGLD(sym=net, dev=dev(), data_inputs=data_inputs, X=X, Y=Y,
X_test=X_test, Y_test=Y_test,
total_iter_num=1000000,
initializer=initializer,
learning_rate=4E-6, prior_precision=1.0, minibatch_size=100,
thin_interval=100, burn_in_iter_num=1000)
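
# Bayesian Dark Knowledge on MNIST: an SGLD teacher is distilled online into a single student MLP
# trained on the teacher's log-probabilities. The wider 800-unit configuration is used when at
# least 10,000 training examples are available.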
def run_mnist_DistilledSGLD(training_num=50000):
X, Y, X_test, Y_test = load_mnist(training_num)
minibatch_size = 100
if training_num >= 10000:
num_hidden = 800
total_iter_num = 1000000
teacher_learning_rate = 1E-6
student_learning_rate = 0.0001
teacher_prior = 1
student_prior = 0.1
perturb_deviation = 0.1
else:
num_hidden = 400
total_iter_num = 20000
teacher_learning_rate = 4E-5
student_learning_rate = 0.0001
teacher_prior = 1
student_prior = 0.1
perturb_deviation = 0.001
teacher_net = get_mnist_sym(num_hidden=num_hidden)
logsoftmax = LogSoftmax()
student_net = get_mnist_sym(output_op=logsoftmax, num_hidden=num_hidden)
data_shape = (minibatch_size,) + X.shape[1::]
teacher_data_inputs = {'data': nd.zeros(data_shape, ctx=dev()),
'softmax_label': nd.zeros((minibatch_size,), ctx=dev())}
student_data_inputs = {'data': nd.zeros(data_shape, ctx=dev()),
'softmax_label': nd.zeros((minibatch_size, 10), ctx=dev())}
teacher_initializer = BiasXavier(factor_type="in", magnitude=1)
student_initializer = BiasXavier(factor_type="in", magnitude=1)
student_exe, student_params, _ = \
DistilledSGLD(teacher_sym=teacher_net, student_sym=student_net,
teacher_data_inputs=teacher_data_inputs,
student_data_inputs=student_data_inputs,
X=X, Y=Y, X_test=X_test, Y_test=Y_test, total_iter_num=total_iter_num,
student_initializer=student_initializer,
teacher_initializer=teacher_initializer,
student_optimizing_algorithm="adam",
teacher_learning_rate=teacher_learning_rate,
student_learning_rate=student_learning_rate,
teacher_prior_precision=teacher_prior, student_prior_precision=student_prior,
perturb_deviation=perturb_deviation, minibatch_size=100, dev=dev())
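
# SGLD posterior sampling for the toy regression teacher network.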
def run_toy_SGLD():
X, Y, X_test, Y_test = load_toy()
minibatch_size = 1
teacher_noise_precision = 1.0 / 9.0
net = get_toy_sym(True, teacher_noise_precision)
data_shape = (minibatch_size,) + X.shape[1::]
data_inputs = {'data': nd.zeros(data_shape, ctx=dev()),
'teacher_output_label': nd.zeros((minibatch_size, 1), ctx=dev())}
initializer = mx.init.Uniform(0.07)
exe, params, _ = \
SGLD(sym=net, data_inputs=data_inputs,
X=X, Y=Y, X_test=X_test, Y_test=Y_test, total_iter_num=50000,
initializer=initializer,
learning_rate=1E-4,
# lr_scheduler=mx.lr_scheduler.FactorScheduler(100000, 0.5),
prior_precision=0.1,
burn_in_iter_num=1000,
thin_interval=10,
task='regression',
minibatch_size=minibatch_size, dev=dev())
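
# Distilled SGLD on the toy regression problem: the student learns a Gaussian predictive
# distribution (mean and log-variance) matched to the teacher via regression_student_grad above.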
def run_toy_DistilledSGLD():
X, Y, X_test, Y_test = load_toy()
minibatch_size = 1
teacher_noise_precision = 1.0
teacher_net = get_toy_sym(True, teacher_noise_precision)
student_net = get_toy_sym(False)
data_shape = (minibatch_size,) + X.shape[1::]
teacher_data_inputs = {'data': nd.zeros(data_shape, ctx=dev()),
'teacher_output_label': nd.zeros((minibatch_size, 1), ctx=dev())}
student_data_inputs = {'data': nd.zeros(data_shape, ctx=dev())}
# 'softmax_label': nd.zeros((minibatch_size, 10), ctx=dev())}
teacher_initializer = mx.init.Uniform(0.07)
student_initializer = mx.init.Uniform(0.07)
student_grad_f = lambda student_outputs, teacher_pred: \
regression_student_grad(student_outputs, teacher_pred, teacher_noise_precision)
student_exe, student_params, _ = \
DistilledSGLD(teacher_sym=teacher_net, student_sym=student_net,
teacher_data_inputs=teacher_data_inputs,
student_data_inputs=student_data_inputs,
X=X, Y=Y, X_test=X_test, Y_test=Y_test, total_iter_num=80000,
teacher_initializer=teacher_initializer,
student_initializer=student_initializer,
teacher_learning_rate=1E-4, student_learning_rate=0.01,
# teacher_lr_scheduler=mx.lr_scheduler.FactorScheduler(100000, 0.5),
student_lr_scheduler=mx.lr_scheduler.FactorScheduler(8000, 0.8),
student_grad_f=student_grad_f,
teacher_prior_precision=0.1, student_prior_precision=0.001,
perturb_deviation=0.1, minibatch_size=minibatch_size, task='regression',
dev=dev())
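
# Full-batch Hamiltonian Monte Carlo on the toy regression problem (minibatch_size equals the
# training-set size); L presumably sets the number of leapfrog steps per proposal.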
def run_toy_HMC():
X, Y, X_test, Y_test = load_toy()
minibatch_size = Y.shape[0]
noise_precision = 1 / 9.0
net = get_toy_sym(True, noise_precision)
data_shape = (minibatch_size,) + X.shape[1::]
data_inputs = {'data': nd.zeros(data_shape, ctx=dev()),
'teacher_output_label': nd.zeros((minibatch_size, 1), ctx=dev())}
initializer = mx.init.Uniform(0.07)
sample_pool = HMC(net, data_inputs=data_inputs, X=X, Y=Y, X_test=X_test, Y_test=Y_test,
sample_num=300000, initializer=initializer, prior_precision=1.0,
learning_rate=1E-3, L=10, dev=dev())
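
# Synthetic Gaussian-mixture experiment from the SGLD paper: theta is updated with the built-in
# 'sgld' optimizer under a polynomially decaying step size (SGLDScheduler), and the collected
# samples are plotted as a 2-D histogram that should recover the bimodal posterior over
# (theta1, theta2).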
def run_synthetic_SGLD():
theta1 = 0
theta2 = 1
sigma1 = numpy.sqrt(10)
sigma2 = 1
sigmax = numpy.sqrt(2)
X = load_synthetic(theta1=theta1, theta2=theta2, sigmax=sigmax, num=100)
minibatch_size = 1
total_iter_num = 1000000
lr_scheduler = SGLDScheduler(begin_rate=0.01, end_rate=0.0001, total_iter_num=total_iter_num,
factor=0.55)
optimizer = mx.optimizer.create('sgld',
learning_rate=None,
rescale_grad=1.0,
lr_scheduler=lr_scheduler,
wd=0)
updater = mx.optimizer.get_updater(optimizer)
theta = mx.random.normal(0, 1, (2,), mx.cpu())
grad = nd.empty((2,), mx.cpu())
samples = numpy.zeros((2, total_iter_num))
start = time.time()
    for i in range(total_iter_num):
if (i + 1) % 100000 == 0:
end = time.time()
print("Iter:%d, Time spent: %f" % (i + 1, end - start))
start = time.time()
ind = numpy.random.randint(0, X.shape[0])
        synthetic_grad(X[ind], theta, sigma1, sigma2, sigmax,
                       rescale_grad=X.shape[0] / float(minibatch_size), grad=grad)
updater('theta', grad, theta)
samples[:, i] = theta.asnumpy()
plt.hist2d(samples[0, :], samples[1, :], (200, 200), cmap=plt.cm.jet)
plt.colorbar()
plt.show()
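
# Command-line entry point. Example invocations (script filename is illustrative):
#   python bdk_demo.py -d 1 -l 2 -t 50000   # DistilledSGLD on MNIST
#   python bdk_demo.py -d 0 -l 1            # SGLD on the toy regression problem
#   python bdk_demo.py -d 2                 # synthetic-data SGLD experiment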
if __name__ == '__main__':
numpy.random.seed(100)
mx.random.seed(100)
    parser = argparse.ArgumentParser(
        description="Examples from the papers [NIPS 2015] Bayesian Dark Knowledge and "
                    "[ICML 2011] Bayesian Learning via Stochastic Gradient Langevin Dynamics")
parser.add_argument("-d", "--dataset", type=int, default=1,
help="Dataset to use. 0 --> TOY, 1 --> MNIST, 2 --> Synthetic Data in "
"the SGLD paper")
parser.add_argument("-l", "--algorithm", type=int, default=2,
help="Type of algorithm to use. 0 --> SGD, 1 --> SGLD, other-->DistilledSGLD")
parser.add_argument("-t", "--training", type=int, default=50000,
help="Number of training samples")
args = parser.parse_args()
training_num = args.training
if args.dataset == 1:
if 0 == args.algorithm:
run_mnist_SGD(training_num)
elif 1 == args.algorithm:
run_mnist_SGLD(training_num)
else:
run_mnist_DistilledSGLD(training_num)
elif args.dataset == 0:
if 1 == args.algorithm:
run_toy_SGLD()
elif 2 == args.algorithm:
run_toy_DistilledSGLD()
elif 3 == args.algorithm:
run_toy_HMC()
else:
run_synthetic_SGLD()