# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import mxnet as mx
from mxnet.test_utils import *
from mxnet.base import MXNetError
import numpy as np
import os
import gzip
import pickle
import time
try:
import h5py
except ImportError:
h5py = None
import sys
from common import assertRaises
import unittest
def test_MNISTIter():
# prepare data
get_mnist_ubyte()
batch_size = 100
train_dataiter = mx.io.MNISTIter(
image="data/train-images-idx3-ubyte",
label="data/train-labels-idx1-ubyte",
data_shape=(784,),
batch_size=batch_size, shuffle=1, flat=1, silent=0, seed=10)
# test_loop
    nbatch = 60000 // batch_size
batch_count = 0
for batch in train_dataiter:
batch_count += 1
assert(nbatch == batch_count)
# test_reset
train_dataiter.reset()
train_dataiter.iter_next()
label_0 = train_dataiter.getlabel().asnumpy().flatten()
train_dataiter.iter_next()
train_dataiter.iter_next()
train_dataiter.iter_next()
train_dataiter.iter_next()
train_dataiter.reset()
train_dataiter.iter_next()
label_1 = train_dataiter.getlabel().asnumpy().flatten()
    assert(np.array_equal(label_0, label_1))
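
# Illustrative sketch (not part of the original test suite): the reset check above
# generalizes to any DataIter; `_first_batch_labels` is a hypothetical helper name.
# e.g. np.array_equal(_first_batch_labels(it), _first_batch_labels(it)) should hold
# for an iterator with a fixed seed, such as the MNISTIter above.
def _first_batch_labels(data_iter):
    """Return the labels of the first batch after a reset, as a flat numpy array."""
    data_iter.reset()
    data_iter.iter_next()
    return data_iter.getlabel().asnumpy().flatten()
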
def test_Cifar10Rec():
get_cifar10()
dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
rand_crop=False,
        rand_mirror=False,
shuffle=False,
data_shape=(3, 28, 28),
batch_size=100,
preprocess_threads=4,
prefetch_buffer=1)
labelcount = [0 for i in range(10)]
batchcount = 0
for batch in dataiter:
npdata = batch.data[0].asnumpy().flatten().sum()
sys.stdout.flush()
batchcount += 1
nplabel = batch.label[0].asnumpy()
for i in range(nplabel.shape[0]):
labelcount[int(nplabel[i])] += 1
for i in range(10):
assert(labelcount[i] == 5000)
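
# Illustrative sketch (not part of the original test suite): the per-class counting
# done above for CIFAR-10 works for any iterator whose labels are class indices;
# `_count_labels` is a hypothetical helper name.
def _count_labels(data_iter, num_classes):
    counts = [0] * num_classes
    for batch in data_iter:
        for lbl in batch.label[0].asnumpy().flatten():
            counts[int(lbl)] += 1
    return counts
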
def test_NDArrayIter():
data = np.ones([1000, 2, 2])
label = np.ones([1000, 1])
for i in range(1000):
data[i] = i / 100
label[i] = i / 100
dataiter = mx.io.NDArrayIter(
data, label, 128, True, last_batch_handle='pad')
batchidx = 0
for batch in dataiter:
batchidx += 1
assert(batchidx == 8)
dataiter = mx.io.NDArrayIter(
data, label, 128, False, last_batch_handle='pad')
batchidx = 0
labelcount = [0 for i in range(10)]
for batch in dataiter:
label = batch.label[0].asnumpy().flatten()
assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
for i in range(label.shape[0]):
labelcount[int(label[i])] += 1
for i in range(10):
if i == 0:
assert(labelcount[i] == 124)
else:
assert(labelcount[i] == 100)
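
# Illustrative sketch (not part of the original test suite): with 1000 samples and a
# batch size of 128, last_batch_handle='pad' pads the final batch and yields 8 batches,
# while 'discard' drops the incomplete batch and yields only 7.
def _ndarrayiter_batch_count_example():
    data = np.zeros([1000, 2, 2])
    padded = sum(1 for _ in mx.io.NDArrayIter(data, None, 128, last_batch_handle='pad'))
    discarded = sum(1 for _ in mx.io.NDArrayIter(data, None, 128, last_batch_handle='discard'))
    assert padded == 8 and discarded == 7
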
def test_NDArrayIter_h5py():
if not h5py:
return
data = np.ones([1000, 2, 2])
label = np.ones([1000, 1])
for i in range(1000):
data[i] = i / 100
label[i] = i / 100
try:
os.remove("ndarraytest.h5")
except OSError:
pass
    with h5py.File("ndarraytest.h5", "w") as f:
f.create_dataset("data", data=data)
f.create_dataset("label", data=label)
dataiter = mx.io.NDArrayIter(
f["data"], f["label"], 128, True, last_batch_handle='pad')
batchidx = 0
for batch in dataiter:
batchidx += 1
assert(batchidx == 8)
dataiter = mx.io.NDArrayIter(
f["data"], f["label"], 128, False, last_batch_handle='pad')
labelcount = [0 for i in range(10)]
for batch in dataiter:
label = batch.label[0].asnumpy().flatten()
assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
for i in range(label.shape[0]):
labelcount[int(label[i])] += 1
try:
os.remove("ndarraytest.h5")
except OSError:
pass
for i in range(10):
if i == 0:
assert(labelcount[i] == 124)
else:
assert(labelcount[i] == 100)
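
# Illustrative sketch (not part of the original test suite): the temporary HDF5 file
# above could also be cleaned up with try/finally so it is removed even when an
# assertion fails; `_with_temp_h5` is a hypothetical helper and assumes h5py is present.
def _with_temp_h5(path, data, label, body):
    try:
        with h5py.File(path, "w") as f:
            f.create_dataset("data", data=data)
            f.create_dataset("label", data=label)
            body(f["data"], f["label"])
    finally:
        if os.path.exists(path):
            os.remove(path)
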
def test_NDArrayIter_csr():
# creating toy data
num_rows = rnd.randint(5, 15)
num_cols = rnd.randint(1, 20)
batch_size = rnd.randint(1, num_rows)
shape = (num_rows, num_cols)
csr, _ = rand_sparse_ndarray(shape, 'csr')
dns = csr.asnumpy()
    # a CSRNDArray or scipy.sparse.csr_matrix input with last_batch_handle
    # other than 'discard' should raise NotImplementedError
assertRaises(NotImplementedError, mx.io.NDArrayIter,
{'data': csr}, dns, batch_size)
try:
import scipy.sparse as spsp
train_data = spsp.csr_matrix(dns)
assertRaises(NotImplementedError, mx.io.NDArrayIter,
{'data': train_data}, dns, batch_size)
except ImportError:
pass
# CSRNDArray with shuffle
csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns, batch_size,
shuffle=True, last_batch_handle='discard'))
num_batch = 0
for batch in csr_iter:
num_batch += 1
assert(num_batch == num_rows // batch_size)
# make iterators
csr_iter = iter(mx.io.NDArrayIter(
csr, csr, batch_size, last_batch_handle='discard'))
begin = 0
for batch in csr_iter:
expected = np.zeros((batch_size, num_cols))
end = begin + batch_size
expected[:num_rows - begin] = dns[begin:end]
if end > num_rows:
expected[num_rows - begin:] = dns[0:end - num_rows]
assert_almost_equal(batch.data[0].asnumpy(), expected)
begin += batch_size
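
# Illustrative sketch (not part of the original test suite): a CSRNDArray can also be
# built from a dense array with tostype('csr'); as the NotImplementedError check above
# shows, sparse inputs require last_batch_handle='discard', which drops the partial batch.
def _csr_iter_example():
    dense = np.arange(12.0).reshape(6, 2)
    csr = mx.nd.array(dense).tostype('csr')
    csr_iter = mx.io.NDArrayIter(csr, None, batch_size=4, last_batch_handle='discard')
    assert sum(1 for _ in csr_iter) == 6 // 4
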
def test_LibSVMIter():
def check_libSVMIter_synthetic():
cwd = os.getcwd()
data_path = os.path.join(cwd, 'data.t')
label_path = os.path.join(cwd, 'label.t')
with open(data_path, 'w') as fout:
fout.write('1.0 0:0.5 2:1.2\n')
fout.write('-2.0\n')
fout.write('-3.0 0:0.6 1:2.4 2:1.2\n')
fout.write('4 2:-1.2\n')
with open(label_path, 'w') as fout:
fout.write('1.0\n')
fout.write('-2.0 0:0.125\n')
fout.write('-3.0 2:1.2\n')
fout.write('4 1:1.0 2:-1.2\n')
data_dir = os.path.join(cwd, 'data')
data_train = mx.io.LibSVMIter(data_libsvm=data_path, label_libsvm=label_path,
data_shape=(3, ), label_shape=(3, ), batch_size=3)
first = mx.nd.array([[0.5, 0., 1.2], [0., 0., 0.], [0.6, 2.4, 1.2]])
second = mx.nd.array([[0., 0., -1.2], [0.5, 0., 1.2], [0., 0., 0.]])
i = 0
for batch in iter(data_train):
expected = first.asnumpy() if i == 0 else second.asnumpy()
data = data_train.getdata()
data.check_format(True)
assert_almost_equal(data.asnumpy(), expected)
i += 1
def check_libSVMIter_news_data():
news_metadata = {
'name': 'news20.t',
'origin_name': 'news20.t.bz2',
'url': "https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/news20.t.bz2",
'feature_dim': 62060 + 1,
'num_classes': 20,
'num_examples': 3993,
}
batch_size = 33
num_examples = news_metadata['num_examples']
data_dir = os.path.join(os.getcwd(), 'data')
get_bz2_data(data_dir, news_metadata['name'], news_metadata['url'],
news_metadata['origin_name'])
path = os.path.join(data_dir, news_metadata['name'])
data_train = mx.io.LibSVMIter(data_libsvm=path, data_shape=(news_metadata['feature_dim'],),
batch_size=batch_size)
for epoch in range(2):
num_batches = 0
for batch in data_train:
# check the range of labels
data = batch.data[0]
label = batch.label[0]
data.check_format(True)
assert(np.sum(label.asnumpy() > 20) == 0)
assert(np.sum(label.asnumpy() <= 0) == 0)
num_batches += 1
expected_num_batches = num_examples / batch_size
assert(num_batches == int(expected_num_batches)), num_batches
data_train.reset()
def check_libSVMIter_exception():
cwd = os.getcwd()
data_path = os.path.join(cwd, 'data.t')
label_path = os.path.join(cwd, 'label.t')
with open(data_path, 'w') as fout:
fout.write('1.0 0:0.5 2:1.2\n')
fout.write('-2.0\n')
            # the line below has a negative index and should raise an exception
fout.write('-3.0 -1:0.6 1:2.4 2:1.2\n')
fout.write('4 2:-1.2\n')
with open(label_path, 'w') as fout:
fout.write('1.0\n')
fout.write('-2.0 0:0.125\n')
fout.write('-3.0 2:1.2\n')
fout.write('4 1:1.0 2:-1.2\n')
data_dir = os.path.join(cwd, 'data')
data_train = mx.io.LibSVMIter(data_libsvm=data_path, label_libsvm=label_path,
data_shape=(3, ), label_shape=(3, ), batch_size=3)
for batch in iter(data_train):
            data_train.getdata().asnumpy()
check_libSVMIter_synthetic()
check_libSVMIter_news_data()
assertRaises(MXNetError, check_libSVMIter_exception)
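
# Illustrative sketch (not part of the original test suite): each LibSVM line is
# "<label> <index>:<value> ...", so with data_shape=(3,) the line '1.0 0:0.5 2:1.2'
# parses to the dense row [0.5, 0.0, 1.2], matching `first` in check_libSVMIter_synthetic.
# The file name below is arbitrary.
def _libsvm_row_example(tmp_path='libsvm_example.t'):
    with open(tmp_path, 'w') as fout:
        fout.write('1.0 0:0.5 2:1.2\n')
    data_train = mx.io.LibSVMIter(data_libsvm=tmp_path, data_shape=(3,), batch_size=1)
    row = next(iter(data_train)).data[0].asnumpy()
    os.remove(tmp_path)
    assert_almost_equal(row, np.array([[0.5, 0.0, 1.2]]))
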
def test_DataBatch():
from nose.tools import ok_
from mxnet.io import DataBatch
import re
batch = DataBatch(data=[mx.nd.ones((2, 3))])
    ok_(re.match(
        r'DataBatch: data shapes: \[\(2L?, 3L?\)\] label shapes: None', str(batch)))
batch = DataBatch(data=[mx.nd.ones((2, 3)), mx.nd.ones(
(7, 8))], label=[mx.nd.ones((4, 5))])
    ok_(re.match(
        r'DataBatch: data shapes: \[\(2L?, 3L?\), \(7L?, 8L?\)\] label shapes: \[\(4L?, 5L?\)\]', str(batch)))
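
# Illustrative sketch (not part of the original test suite): DataBatch simply wraps
# lists of NDArrays, and its string form reports the shapes that the regexes above match.
def _databatch_repr_example():
    from mxnet.io import DataBatch
    batch = DataBatch(data=[mx.nd.zeros((2, 3))], label=[mx.nd.zeros((2,))])
    assert 'data shapes' in str(batch) and 'label shapes' in str(batch)
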
def test_CSVIter():
def check_CSVIter_synthetic(dtype='float32'):
cwd = os.getcwd()
data_path = os.path.join(cwd, 'data.t')
label_path = os.path.join(cwd, 'label.t')
entry_str = '1'
        if dtype == 'int32':
            entry_str = '200000001'
        if dtype == 'int64':
            entry_str = '2147483648'
with open(data_path, 'w') as fout:
for i in range(1000):
fout.write(','.join([entry_str for _ in range(8*8)]) + '\n')
with open(label_path, 'w') as fout:
for i in range(1000):
fout.write('0\n')
data_train = mx.io.CSVIter(data_csv=data_path, data_shape=(8, 8),
label_csv=label_path, batch_size=100, dtype=dtype)
expected = mx.nd.ones((100, 8, 8), dtype=dtype) * int(entry_str)
for batch in iter(data_train):
data_batch = data_train.getdata()
assert_almost_equal(data_batch.asnumpy(), expected.asnumpy())
assert data_batch.asnumpy().dtype == expected.asnumpy().dtype
for dtype in ['int32', 'int64', 'float32']:
check_CSVIter_synthetic(dtype=dtype)
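
# Illustrative sketch (not part of the original test suite): dtype controls the storage
# type of the batches CSVIter returns, which is what the dtype assertion above verifies.
# The file name below is arbitrary and label_csv is omitted here.
def _csviter_dtype_example(path='csv_dtype_example.t'):
    with open(path, 'w') as fout:
        fout.write('1,2,3,4\n' * 5)
    csv_iter = mx.io.CSVIter(data_csv=path, data_shape=(4,), batch_size=5, dtype='int64')
    batch_dtype = next(iter(csv_iter)).data[0].dtype
    os.remove(path)
    assert batch_dtype == np.int64
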
@unittest.skip("Flaky test: https://github.com/apache/incubator-mxnet/issues/11359")
def test_ImageRecordIter_seed_augmentation():
get_cifar10()
seed_aug = 3
    # check that identical images are produced when seed_aug is fixed
dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3,
rand_crop=True,
rand_mirror=True,
max_random_scale=1.3,
max_random_illumination=3,
max_rotate_angle=10,
random_l=50,
random_s=40,
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
data = batch.data[0].asnumpy().astype(np.uint8)
dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3,
rand_crop=True,
rand_mirror=True,
max_random_scale=1.3,
max_random_illumination=3,
max_rotate_angle=10,
random_l=50,
random_s=40,
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
    assert(np.array_equal(data, data2))
    # check that different images are produced after seed_aug is changed
dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3,
rand_crop=True,
rand_mirror=True,
max_random_scale=1.3,
max_random_illumination=3,
max_rotate_angle=10,
random_l=50,
random_s=40,
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug+1)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
    assert(not np.array_equal(data, data2))
    # check that the iterator stays deterministic when only seed_aug is set and
    # no random augmentation is enabled
dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data = batch.data[0].asnumpy().astype(np.uint8)
dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
    assert(np.array_equal(data, data2))
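
# Illustrative sketch (not part of the original test suite): the three near-identical
# ImageRecordIter constructions above could share a hypothetical helper that builds the
# CIFAR-10 iterator for a given seed_aug and returns its first batch as uint8.
def _first_augmented_batch(seed_aug, **extra_aug):
    dataiter = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/train.rec",
        mean_img="data/cifar/cifar10_mean.bin",
        shuffle=False,
        data_shape=(3, 28, 28),
        batch_size=3,
        seed_aug=seed_aug,
        **extra_aug)
    return dataiter.next().data[0].asnumpy().astype(np.uint8)
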
if __name__ == "__main__":
test_NDArrayIter()
if h5py:
test_NDArrayIter_h5py()
test_MNISTIter()
test_Cifar10Rec()
test_LibSVMIter()
test_NDArrayIter_csr()
test_CSVIter()
test_ImageRecordIter_seed_augmentation()