# blob: 2a585ead385b0c22c797db093d8f35acafebd32f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import mxnet as mx
import zipfile
def get_mnist(data_dir):
    """Make sure the raw MNIST idx files are present in ``data_dir``.

    The four files ``train-images-idx3-ubyte``, ``train-labels-idx1-ubyte``,
    ``t10k-images-idx3-ubyte`` and ``t10k-labels-idx1-ubyte`` are looked up
    directly inside ``data_dir``. If any of them is missing, ``mnist.zip`` is
    fetched with ``mx.test_utils.download``, extracted into ``data_dir`` and
    the archive is deleted afterwards.

    The process working directory is never changed by this function.

    Parameters
    ----------
    data_dir : str
        Directory that holds (or will hold) the MNIST files. It is created,
        together with any missing parent directories, when it does not exist.
    """
    # Work with absolute paths instead of os.chdir(): the caller's working
    # directory stays intact even for nested/absolute data_dir values or when
    # the download fails half-way.
    root = os.path.abspath(data_dir)
    if not os.path.isdir(root):
        # os.makedirs avoids building a shell command from the path, so
        # spaces or shell metacharacters in data_dir are handled safely.
        os.makedirs(root)

    required = (
        'train-images-idx3-ubyte',
        'train-labels-idx1-ubyte',
        't10k-images-idx3-ubyte',
        't10k-labels-idx1-ubyte',
    )
    if all(os.path.exists(os.path.join(root, name)) for name in required):
        return

    zippath = os.path.join(root, "mnist.zip")
    mx.test_utils.download("http://data.mxnet.io/mxnet/data/mnist.zip", zippath)
    try:
        with zipfile.ZipFile(zippath, "r") as zf:
            zf.extractall(root)
    finally:
        # Remove the archive even if extraction fails, so a broken download
        # is not left behind in data_dir.
        os.remove(zippath)
def get_cifar10(data_dir):
    """Make sure the CIFAR-10 record files are present in ``data_dir``.

    If ``train.rec`` or ``test.rec`` is missing from ``data_dir``,
    ``cifar10.zip`` is fetched with ``mx.test_utils.download`` and extracted
    into ``data_dir``. Every entry of the extracted ``cifar`` sub-directory is
    then moved up into ``data_dir`` and the emptied ``cifar`` directory and
    the archive are removed.

    The process working directory is never changed by this function.

    Parameters
    ----------
    data_dir : str
        Directory that holds (or will hold) the CIFAR-10 files. It is
        created, together with any missing parent directories, when it does
        not exist.
    """
    import glob

    # Work with absolute paths instead of os.chdir(): the caller's working
    # directory stays intact even when the download or extraction raises.
    root = os.path.abspath(data_dir)
    if not os.path.isdir(root):
        # os.makedirs avoids building a shell command from the path, so
        # spaces or shell metacharacters in data_dir are handled safely.
        os.makedirs(root)

    have_train = os.path.exists(os.path.join(root, 'train.rec'))
    have_test = os.path.exists(os.path.join(root, 'test.rec'))
    if have_train and have_test:
        return

    zippath = os.path.join(root, "cifar10.zip")
    mx.test_utils.download("http://data.mxnet.io/mxnet/data/cifar10.zip", zippath)
    try:
        with zipfile.ZipFile(zippath, "r") as zf:
            zf.extractall(root)
    finally:
        # Remove the archive even if extraction fails, so a broken download
        # is not left behind in data_dir.
        os.remove(zippath)

    # Flatten the "cifar" directory that was just extracted into data_dir.
    extracted = os.path.join(root, "cifar")
    for path in glob.glob(os.path.join(extracted, "*")):
        os.rename(path, os.path.join(root, os.path.basename(path)))
    os.rmdir(extracted)
def get_cifar10_iterator(args, kv):
    """Build the CIFAR-10 training and validation ``ImageRecordIter`` pair.

    Parameters
    ----------
    args : object
        Must expose ``data_dir`` (str) and ``batch_size``.
    kv : object
        Must expose ``num_workers`` and ``rank``; they are passed as
        ``num_parts`` / ``part_index`` to both iterators.
        (Presumably an MXNet KVStore - confirm against the caller.)

    Returns
    -------
    tuple
        ``(train, val)`` iterators reading ``train.rec`` / ``test.rec`` and
        ``mean.bin`` from the data directory, with a ``(3, 28, 28)`` data
        shape. Random crop and mirror are enabled for ``train`` only.
    """
    data_shape = (3, 28, 28)
    data_dir = args.data_dir
    if os.name == "nt":
        # Normalise to exactly one trailing backslash on Windows. Only strip
        # the last character when it really is a separator; slicing it off
        # unconditionally would turn "data" into "dat\".
        if data_dir.endswith(("/", "\\")):
            data_dir = data_dir[:-1]
        data_dir = data_dir + "\\"

    # A data_dir containing "://" is not fetched locally
    # (presumably a remote URI - confirm against the caller).
    if '://' not in args.data_dir:
        get_cifar10(data_dir)

    mean_path = os.path.join(data_dir, "mean.bin")

    train = mx.io.ImageRecordIter(
        path_imgrec=os.path.join(data_dir, "train.rec"),
        mean_img=mean_path,
        data_shape=data_shape,
        batch_size=args.batch_size,
        rand_crop=True,
        rand_mirror=True,
        num_parts=kv.num_workers,
        part_index=kv.rank)

    val = mx.io.ImageRecordIter(
        path_imgrec=os.path.join(data_dir, "test.rec"),
        mean_img=mean_path,
        rand_crop=False,
        rand_mirror=False,
        data_shape=data_shape,
        batch_size=args.batch_size,
        num_parts=kv.num_workers,
        part_index=kv.rank)

    return (train, val)