blob: ecd117d862181d9f0b138dd2ca7141315888e066 [file] [log] [blame]
import os
import numpy as np
from sklearn.datasets import fetch_mldata
def get_mnist():
    """Fetch the 'MNIST original' dataset and return it shuffled and rescaled.

    The dataset is requested through ``fetch_mldata`` with ``data_home`` set
    to ``../../data`` relative to the directory containing this file.

    Side effect: seeds NumPy's *global* random state with 1234 so that the
    shuffle order is the same on every call.

    Returns:
        A tuple ``(X, Y)``:
          * ``X`` -- ``mnist.data`` rows in shuffled order, cast to
            ``float32`` and multiplied by 0.02.
          * ``Y`` -- ``mnist.target`` entries in the same shuffled order.

    NOTE(review): ``fetch_mldata`` appears to have been deprecated and later
    removed from scikit-learn -- confirm against the pinned version.
    """
    # Fix the global seed first so the permutation below is deterministic.
    np.random.seed(1234)

    # Resolve the data directory relative to this source file, not the CWD.
    module_dir = os.path.dirname(
        os.path.abspath(os.path.expanduser(__file__))
    )
    cache_dir = os.path.join(module_dir, '../../data')

    dataset = fetch_mldata('MNIST original', data_home=cache_dir)

    # One shared permutation keeps samples and labels aligned.
    num_samples = dataset.data.shape[0]
    order = np.random.permutation(num_samples)

    images = dataset.data[order].astype(np.float32) * 0.02
    labels = dataset.target[order]
    return images, labels