# blob: 5fe61b1850417006c64b36d93b1e46f66495b616 [file] [log] [blame]
# pylint: skip-file
import mxnet as mx
import numpy as np
import os, gzip
import pickle as pickle
import time
import sys
from common import get_data
def test_MNISTIter():
    """Check the batch count and ``reset()`` behaviour of ``mx.io.MNISTIter``.

    The iterator is built over the ubyte files under ``data/`` and the test
    asserts that:

    * one full pass yields ``60000 // batch_size`` batches, and
    * after ``reset()`` the first batch carries exactly the same labels as
      the first batch seen before the reset.
    """
    # prepare data (presumably places the ubyte files under data/ --
    # confirm against common.get_data)
    get_data.GetMNIST_ubyte()

    batch_size = 100
    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        data_shape=(784,),
        batch_size=batch_size, shuffle=1, flat=1, silent=0, seed=10)

    # test_loop
    # Floor division keeps the expected count an integer on Python 2 and 3.
    nbatch = 60000 // batch_size
    batch_count = 0
    for batch in train_dataiter:
        batch_count += 1
    assert nbatch == batch_count

    # test_reset
    train_dataiter.reset()
    train_dataiter.iter_next()
    label_0 = train_dataiter.getlabel().asnumpy().flatten()
    # Move a few batches further so that reset() has something to rewind.
    for _ in range(4):
        train_dataiter.iter_next()
    train_dataiter.reset()
    train_dataiter.iter_next()
    label_1 = train_dataiter.getlabel().asnumpy().flatten()
    # Compare element-wise: summing the differences could cancel out and
    # hide a mismatch between the two label vectors.
    assert (label_0 == label_1).all()
def test_Cifar10Rec():
    """Check that ``mx.io.ImageRecordIter`` yields 5000 samples per label.

    The test is currently disabled: the early ``return`` below skips it to
    save time, so everything after it is unreachable and kept only so the
    test can be re-enabled by removing that statement.
    """
    # skip-this test for saving time
    return
    get_data.GetCifar10()
    dataiter = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/train.rec",
        mean_img="data/cifar/cifar10_mean.bin",
        rand_crop=False,
        rand_mirror=False,
        shuffle=False,
        data_shape=(3, 28, 28),
        batch_size=100,
        preprocess_threads=4,
        prefetch_buffer=1)

    # One counter per label value 0..9.
    labelcount = [0] * 10
    for batch in dataiter:
        # Convert the batch data to NumPy as well, so the image payload is
        # read and not only the labels.
        batch.data[0].asnumpy().flatten().sum()
        sys.stdout.flush()
        nplabel = batch.label[0].asnumpy()
        for i in range(nplabel.shape[0]):
            labelcount[int(nplabel[i])] += 1

    for count in labelcount:
        assert count == 5000
def test_NDArrayIter():
    """Check batch count and padding behaviour of ``mx.io.NDArrayIter``.

    1000 samples are created whose data and label both equal
    ``index // 100`` (ten classes of 100 samples each).  With a batch size of
    128 and ``last_batch_handle='pad'`` the test asserts that:

    * one pass yields 8 batches,
    * within every batch the data and label values agree sample by sample,
    * label 0 is seen 124 times and every other label 100 times.
    """
    datas = np.ones([1000, 2, 2])
    labels = np.ones([1000, 1])
    for i in range(1000):
        # Floor division yields the integer class id on Python 2 and 3;
        # plain ``/`` would store fractional values on Python 3.
        datas[i] = i // 100
        labels[i] = i // 100

    # First pass: only count the batches.  The fourth positional argument
    # is presumably ``shuffle`` -- confirm against the NDArrayIter signature.
    dataiter = mx.io.NDArrayIter(datas, labels, 128, True, last_batch_handle='pad')
    batchidx = 0
    for batch in dataiter:
        batchidx += 1
    assert batchidx == 8

    # Second pass: verify data/label pairing and tally the labels.
    dataiter = mx.io.NDArrayIter(datas, labels, 128, False, last_batch_handle='pad')
    labelcount = [0] * 10
    for batch in dataiter:
        label = batch.label[0].asnumpy().flatten()
        assert (batch.data[0].asnumpy()[:, 0, 0] == label).all()
        for i in range(label.shape[0]):
            labelcount[int(label[i])] += 1

    # 8 * 128 - 1000 = 24 extra samples pad the last batch; the assertion
    # expects all of them to carry label 0 (presumably taken from the start
    # of the data -- confirm against the NDArrayIter documentation).
    for i in range(10):
        expected = 124 if i == 0 else 100
        assert labelcount[i] == expected
if __name__ == "__main__":
    # Script entry point: run each iterator test in turn, stopping at the
    # first failure.
    for run_test in (test_NDArrayIter, test_MNISTIter, test_Cifar10Rec):
        run_test()