| import os |
| import mxnet as mx |
| |
def get_mnist(data_dir):
    """Ensure the raw MNIST idx files are present in ``data_dir``.

    The directory is created if it does not exist.  If any of the four
    expected idx files is missing, ``mnist.zip`` is downloaded from
    data.mxnet.io, extracted into ``data_dir`` and then deleted.

    All work is done with explicit paths, so the caller's current working
    directory is never changed.

    :param data_dir: local directory that should hold the MNIST files.
    """
    import zipfile
    try:
        from urllib.request import urlretrieve  # Python 3
    except ImportError:
        from urllib import urlretrieve  # Python 2

    if not os.path.isdir(data_dir):
        # makedirs instead of os.system("mkdir ...") avoids shell injection
        # and handles nested paths and paths with spaces.
        os.makedirs(data_dir)

    required = ('train-images-idx3-ubyte', 'train-labels-idx1-ubyte',
                't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
    if all(os.path.exists(os.path.join(data_dir, name)) for name in required):
        return

    zippath = os.path.join(os.path.abspath(data_dir), "mnist.zip")
    try:
        urlretrieve("http://data.mxnet.io/mxnet/data/mnist.zip", zippath)
        with zipfile.ZipFile(zippath, "r") as zf:
            zf.extractall(data_dir)
    finally:
        # Never leave a (possibly partial) archive behind.
        if os.path.exists(zippath):
            os.remove(zippath)
| |
def get_cifar10(data_dir):
    """Ensure ``train.rec`` and ``test.rec`` are present in ``data_dir``.

    The directory is created if it does not exist.  If either record file
    is missing, ``cifar10.zip`` is downloaded from data.mxnet.io, extracted
    into ``data_dir``, and the contents of its ``cifar/`` sub-folder are
    moved up into ``data_dir`` itself.

    All work is done with explicit paths, so the caller's current working
    directory is never changed (even if the download fails).

    :param data_dir: local directory that should hold the CIFAR-10 files.
    """
    import glob
    import zipfile
    try:
        from urllib.request import urlretrieve  # Python 3
    except ImportError:
        from urllib import urlretrieve  # Python 2

    if not os.path.isdir(data_dir):
        # makedirs instead of os.system("mkdir ...") avoids shell injection
        # and handles nested paths and paths with spaces.
        os.makedirs(data_dir)

    if os.path.exists(os.path.join(data_dir, 'train.rec')) and \
            os.path.exists(os.path.join(data_dir, 'test.rec')):
        return

    dirname = os.path.abspath(data_dir)
    zippath = os.path.join(dirname, "cifar10.zip")
    try:
        urlretrieve("http://data.mxnet.io/mxnet/data/cifar10.zip", zippath)
        with zipfile.ZipFile(zippath, "r") as zf:
            zf.extractall(dirname)
    finally:
        # Never leave a (possibly partial) archive behind.
        if os.path.exists(zippath):
            os.remove(zippath)

    # The archive is expected to unpack into a "cifar/" folder; flatten it
    # so the record files sit directly in data_dir.
    extracted = os.path.join(dirname, "cifar")
    for src in glob.glob(os.path.join(extracted, "*")):
        os.rename(src, os.path.join(dirname, os.path.basename(src)))
    os.rmdir(extracted)
| |
| # data |
def get_cifar10_iterator(args, kv):
    """Build the CIFAR-10 training and validation record iterators.

    For a local ``args.data_dir`` the data is downloaded first if needed
    (via ``get_cifar10``).  A remote URI (anything containing ``://``) is
    used as-is.

    :param args: parsed options; ``data_dir`` and ``batch_size`` are read.
    :param kv: key-value store; ``num_workers`` and ``rank`` are passed as
        ``num_parts`` / ``part_index`` so each worker reads its own part of
        the record file.
    :returns: ``(train, val)`` tuple of ``mx.io.ImageRecordIter``.
    """
    # NOTE(review): CHW output shape of the iterators; presumably a crop of
    # the stored images (random for train) — confirm against the .rec files.
    data_shape = (3, 28, 28)
    data_dir = args.data_dir

    if '://' in data_dir:
        # Remote URI: always join with '/', even on Windows.
        def _path(name):
            return data_dir.rstrip('/') + '/' + name
    else:
        if os.name == "nt":
            # Normalize separators to '\\'.  The old `data_dir[:-1] + "\\"`
            # silently dropped a real character when data_dir had no
            # trailing slash.
            data_dir = os.path.normpath(data_dir)
        get_cifar10(data_dir)

        def _path(name):
            return os.path.join(data_dir, name)

    # Settings shared by the training and validation iterators.
    common = dict(
        mean_img=_path("mean.bin"),
        data_shape=data_shape,
        batch_size=args.batch_size,
        num_parts=kv.num_workers,
        part_index=kv.rank)

    train = mx.io.ImageRecordIter(
        path_imgrec=_path("train.rec"),
        rand_crop=True,
        rand_mirror=True,
        **common)

    val = mx.io.ImageRecordIter(
        path_imgrec=_path("test.rec"),
        rand_crop=False,
        rand_mirror=False,
        **common)

    return (train, val)