blob: 90b01e0144cc0f389e2051a5db898fa623c33540 [file] [log] [blame]
from __future__ import print_function
import numpy
import os
import ssl
def load_mnist(training_num=50000):
data_path = os.path.join(os.path.dirname(os.path.realpath('__file__')), 'mnist.npz')
if not os.path.isfile(data_path):
from six.moves import urllib
origin = (
'https://github.com/sxjscience/mxnet/raw/master/example/bayesian-methods/mnist.npz'
)
print('Downloading data from %s to %s' % (origin, data_path))
context = ssl._create_unverified_context()
urllib.request.urlretrieve(origin, data_path, context=context)
print('Done!')
dat = numpy.load(data_path)
X = (dat['X'][:training_num] / 126.0).astype('float32')
Y = dat['Y'][:training_num]
X_test = (dat['X_test'] / 126.0).astype('float32')
Y_test = dat['Y_test']
Y = Y.reshape((Y.shape[0],))
Y_test = Y_test.reshape((Y_test.shape[0],))
return X, Y, X_test, Y_test
def load_toy():
training_data = numpy.loadtxt('toy_data_train.txt')
testing_data = numpy.loadtxt('toy_data_test_whole.txt')
X = training_data[:, 0].reshape((training_data.shape[0], 1))
Y = training_data[:, 1].reshape((training_data.shape[0], 1))
X_test = testing_data[:, 0].reshape((testing_data.shape[0], 1))
Y_test = testing_data[:, 1].reshape((testing_data.shape[0], 1))
return X, Y, X_test, Y_test
def load_synthetic(theta1, theta2, sigmax, num=20):
flag = numpy.random.randint(0, 2, (num,))
X = flag * numpy.random.normal(theta1, sigmax, (num,)) \
+ (1.0 - flag) * numpy.random.normal(theta1 + theta2, sigmax, (num,))
return X