| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| # pylint: skip-file |
| import mxnet as mx |
| import mxnet.ndarray as nd |
| from mxnet.test_utils import * |
| from mxnet.base import MXNetError |
| import numpy as np |
| import os |
| try: |
| import h5py |
| except ImportError: |
| h5py = None |
| from common import assertRaises |
| import pytest |
| from itertools import zip_longest |
| |
| @pytest.fixture(scope="session") |
| def cifar10(tmpdir_factory): |
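    # download the CIFAR-10 record files once and share them across the session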
| path = str(tmpdir_factory.mktemp('cifar')) |
| get_cifar10(path) |
| return path |
| |
| |
| def test_MNISTIter(tmpdir): |
| # prepare data |
| path = str(tmpdir) |
| get_mnist_ubyte(path) |
| |
| batch_size = 100 |
| train_dataiter = mx.io.MNISTIter( |
| image=os.path.join(path, 'train-images-idx3-ubyte'), |
| label=os.path.join(path, 'train-labels-idx1-ubyte'), |
| data_shape=(784,), |
| batch_size=batch_size, shuffle=1, flat=1, silent=0, seed=10) |
| # test_loop |
    nbatch = 60000 // batch_size
| batch_count = 0 |
| for _ in train_dataiter: |
| batch_count += 1 |
| assert(nbatch == batch_count) |
| # test_reset |
| train_dataiter.reset() |
| train_dataiter.iter_next() |
| label_0 = train_dataiter.getlabel().asnumpy().flatten() |
| train_dataiter.iter_next() |
| train_dataiter.iter_next() |
| train_dataiter.iter_next() |
| train_dataiter.iter_next() |
| train_dataiter.reset() |
| train_dataiter.iter_next() |
| label_1 = train_dataiter.getlabel().asnumpy().flatten() |
    assert (label_0 == label_1).all()
| mx.nd.waitall() |
| |
| def test_Cifar10Rec(cifar10): |
| dataiter = mx.io.ImageRecordIter( |
| path_imgrec=os.path.join(cifar10, 'cifar', 'train.rec'), |
| mean_img=os.path.join(cifar10, 'cifar', 'cifar10_mean.bin'), |
| rand_crop=False, |
        rand_mirror=False,
| shuffle=False, |
| data_shape=(3, 28, 28), |
| batch_size=100, |
| preprocess_threads=4, |
| prefetch_buffer=1) |
| labelcount = [0 for i in range(10)] |
| batchcount = 0 |
| for batch in dataiter: |
        # touch the data to force the batch to be materialized
        batch.data[0].wait_to_read()
| batchcount += 1 |
| nplabel = batch.label[0].asnumpy() |
| for i in range(nplabel.shape[0]): |
| labelcount[int(nplabel[i])] += 1 |
| for i in range(10): |
| assert(labelcount[i] == 5000) |
| |
| @pytest.mark.parametrize('inter_method', [0,1,2,3,4,9,10]) |
| def test_inter_methods_in_augmenter(inter_method, cifar10): |
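    # 0-4 correspond to the OpenCV interpolation modes (nearest, bilinear,
    # bicubic, area, Lanczos); 9 and 10 ask the augmenter to pick a method
    # automatically or at random (assumption based on the ImageRecordIter docs)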
| dataiter = mx.io.ImageRecordIter( |
| path_imgrec=os.path.join(cifar10, 'cifar', 'train.rec'), |
| mean_img=os.path.join(cifar10, 'cifar', 'cifar10_mean.bin'), |
| max_rotate_angle=45, |
| data_shape=(3, 28, 28), |
| batch_size=100, |
| inter_method=inter_method) |
| for _ in dataiter: |
| pass |
| |
| def test_image_iter_exception(cifar10): |
| with pytest.raises(MXNetError): |
| dataiter = mx.io.ImageRecordIter( |
| path_imgrec=os.path.join(cifar10, 'cifar', 'train.rec'), |
| mean_img=os.path.join(cifar10, 'cifar', 'cifar10_mean.bin'), |
| rand_crop=False, |
            rand_mirror=False,
| shuffle=False, |
| data_shape=(5, 28, 28), |
| batch_size=100, |
| preprocess_threads=4, |
| prefetch_buffer=1) |
| for _ in dataiter: |
| pass |
| |
| def _init_NDArrayIter_data(data_type, is_image=False): |
| if is_image: |
| data = nd.random.uniform(0, 255, shape=(5000, 1, 28, 28)) |
| labels = nd.ones((5000, 1)) |
| return data, labels |
| if data_type == 'NDArray': |
| data = nd.ones((1000, 2, 2)) |
| labels = nd.ones((1000, 1)) |
| else: |
| data = np.ones((1000, 2, 2)) |
| labels = np.ones((1000, 1)) |
| for i in range(1000): |
        data[i] = i // 100
        labels[i] = i // 100
| return data, labels |
| |
| |
| def _test_last_batch_handle(data, labels=None, is_image=False): |
| # Test the three parameters 'pad', 'discard', 'roll_over' |
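    # Expected counts assume 1000 samples with labels i // 100 (100 per class)
    # and batch_size=128: 'pad' yields 8 batches and re-uses the first 24
    # samples (class 0 rises to 124), while 'discard' and 'roll_over' yield 7
    # batches and leave out the last 104 samples (class 8 drops to 96). The
    # image case has 5000 samples, hence 40 ('pad') or 39 batches.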
| last_batch_handle_list = ['pad', 'discard', 'roll_over'] |
| if labels is not None and not is_image and len(labels) != 0: |
| labelcount_list = [(124, 100), (100, 96), (100, 96)] |
| if is_image: |
| batch_count_list = [40, 39, 39] |
| else: |
| batch_count_list = [8, 7, 7] |
| |
| for idx in range(len(last_batch_handle_list)): |
| dataiter = mx.io.NDArrayIter( |
| data, labels, 128, False, last_batch_handle=last_batch_handle_list[idx]) |
| batch_count = 0 |
| if labels is not None and len(labels) != 0 and not is_image: |
| labelcount = [0 for i in range(10)] |
| for batch in dataiter: |
| if len(data) == 2: |
| assert len(batch.data) == 2 |
| if labels is not None and len(labels) != 0: |
| if not is_image: |
| label = batch.label[0].asnumpy().flatten() |
                    # check that the data matches the corresponding labels
| assert((batch.data[0].asnumpy()[:, 0, 0] == label).all()) |
| for i in range(label.shape[0]): |
| labelcount[int(label[i])] += 1 |
| else: |
                    assert not batch.label, 'label is not an empty list'
            # cache the last batch of 'pad'; after reset, the first batch of
            # 'roll_over' should contain the same samples (checked below)
| batch_count += 1 |
| if last_batch_handle_list[idx] == 'pad' and \ |
| batch_count == batch_count_list[0]: |
| cache = batch.data[0].asnumpy() |
        # check that the batching functionality worked properly
| if labels is not None and len(labels) != 0 and not is_image: |
| assert labelcount[0] == labelcount_list[idx][0], last_batch_handle_list[idx] |
| assert labelcount[8] == labelcount_list[idx][1], last_batch_handle_list[idx] |
| assert batch_count == batch_count_list[idx] |
    # 'roll_over': after reset, the first batch consists of the samples rolled
    # over from the previous epoch, which equals the cached last 'pad' batch
| dataiter.reset() |
| assert np.array_equal(dataiter.next().data[0].asnumpy(), cache) |
| |
| |
| def _test_shuffle(data, labels=None): |
| dataiter = mx.io.NDArrayIter(data, labels, 1, False) |
| batch_list = [] |
| for batch in dataiter: |
| # cache the original data |
| batch_list.append(batch.data[0].asnumpy()) |
| dataiter = mx.io.NDArrayIter(data, labels, 1, True) |
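    # NDArrayIter exposes the shuffled sample order through its `idx`
    # attribute, which maps each yielded position back to the original index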
| idx_list = dataiter.idx |
| i = 0 |
| for batch in dataiter: |
        # check that each data point was shuffled to its corresponding position
| assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]]) |
| i += 1 |
| |
| |
| def _test_corner_case(): |
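    # 'pad' with a batch_size (205) larger than the dataset (10 samples)
    # should wrap around the data: 20 full copies plus the first 5 elements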
| data = np.arange(10) |
| data_iter = mx.io.NDArrayIter(data=data, batch_size=205, shuffle=False, last_batch_handle='pad') |
| expect = np.concatenate((np.tile(data, 20), np.arange(5))) |
| assert np.array_equal(data_iter.next().data[0].asnumpy(), expect) |
| |
| |
| def test_NDArrayIter(): |
| dtype_list = ['NDArray', 'ndarray'] |
| tested_data_type = [False, True] |
| for dtype in dtype_list: |
| for is_image in tested_data_type: |
| data, labels = _init_NDArrayIter_data(dtype, is_image) |
| _test_last_batch_handle(data, labels, is_image) |
| _test_last_batch_handle([data, data], labels, is_image) |
| _test_last_batch_handle(data=[data, data], is_image=is_image) |
| _test_last_batch_handle( |
| {'data1': data, 'data2': data}, labels, is_image) |
| _test_last_batch_handle(data={'data1': data, 'data2': data}, is_image=is_image) |
| _test_last_batch_handle(data, [], is_image) |
| _test_last_batch_handle(data=data, is_image=is_image) |
| _test_shuffle(data, labels) |
| _test_shuffle([data, data], labels) |
| _test_shuffle([data, data]) |
| _test_shuffle({'data1': data, 'data2': data}, labels) |
| _test_shuffle({'data1': data, 'data2': data}) |
| _test_shuffle(data, []) |
| _test_shuffle(data) |
| _test_corner_case() |
| |
| |
| def test_NDArrayIter_h5py(): |
    if not h5py:
        pytest.skip('h5py is not installed')
| |
| data, labels = _init_NDArrayIter_data('ndarray') |
| |
| try: |
| os.remove('ndarraytest.h5') |
| except OSError: |
| pass |
    with h5py.File('ndarraytest.h5', 'w') as f:
| f.create_dataset('data', data=data) |
| f.create_dataset('label', data=labels) |
| |
| _test_last_batch_handle(f['data'], f['label']) |
| _test_last_batch_handle(f['data'], []) |
| _test_last_batch_handle(f['data']) |
| try: |
| os.remove("ndarraytest.h5") |
| except OSError: |
| pass |
| |
| |
| def _test_NDArrayIter_csr(csr_iter, csr_iter_empty_list, csr_iter_None, num_rows, batch_size): |
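    # iterate the three iterators in lockstep and verify that both an empty
    # list and None as labels produce an empty batch.label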
| num_batch = 0 |
| for _, batch_empty_list, batch_empty_None in zip(csr_iter, csr_iter_empty_list, csr_iter_None): |
        assert not batch_empty_list.label, 'label is not an empty list'
        assert not batch_empty_None.label, 'label is not an empty list'
| num_batch += 1 |
| |
| assert(num_batch == num_rows // batch_size) |
| assertRaises(StopIteration, csr_iter.next) |
| assertRaises(StopIteration, csr_iter_empty_list.next) |
| assertRaises(StopIteration, csr_iter_None.next) |
| |
| |
| def test_NDArrayIter_csr(): |
| # creating toy data |
| num_rows = rnd.randint(5, 15) |
| num_cols = rnd.randint(1, 20) |
| batch_size = rnd.randint(1, num_rows) |
| shape = (num_rows, num_cols) |
| csr, _ = rand_sparse_ndarray(shape, 'csr') |
| dns = csr.asnumpy() |
| |
    # CSRNDArray or scipy.sparse.csr_matrix data with last_batch_handle other
    # than 'discard' should throw NotImplementedError
| assertRaises(NotImplementedError, mx.io.NDArrayIter, |
| {'data': csr}, dns, batch_size) |
    try:
        import scipy.sparse as spsp
    except ImportError:
        spsp = None

    if spsp is not None:
        train_data = spsp.csr_matrix(dns)
        assertRaises(NotImplementedError, mx.io.NDArrayIter,
                     {'data': train_data}, dns, batch_size)

        # scipy.sparse.csr_matrix with shuffle
        csr_iter = iter(mx.io.NDArrayIter({'data': train_data}, dns, batch_size,
                                          shuffle=True, last_batch_handle='discard'))
        csr_iter_empty_list = iter(mx.io.NDArrayIter({'data': train_data}, [], batch_size,
                                                     shuffle=True, last_batch_handle='discard'))
        csr_iter_None = iter(mx.io.NDArrayIter({'data': train_data}, None, batch_size,
                                               shuffle=True, last_batch_handle='discard'))
        _test_NDArrayIter_csr(csr_iter, csr_iter_empty_list,
                              csr_iter_None, num_rows, batch_size)
| |
| # CSRNDArray with shuffle |
| csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns, batch_size, |
| shuffle=True, last_batch_handle='discard')) |
| csr_iter_empty_list = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, [], batch_size, |
| shuffle=True, last_batch_handle='discard')) |
| csr_iter_None = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, None, batch_size, |
| shuffle=True, last_batch_handle='discard')) |
| _test_NDArrayIter_csr(csr_iter, csr_iter_empty_list, |
| csr_iter_None, num_rows, batch_size) |
| |
    # CSRNDArray as both data and label, without shuffle
| csr_iter = iter(mx.io.NDArrayIter( |
| csr, csr, batch_size, last_batch_handle='discard')) |
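    # with 'discard', only num_rows // batch_size complete batches are
    # yielded, so each expected batch is a contiguous slice of the dense
    # copy; the wrap-around branch below should never actually trigger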
| begin = 0 |
| for batch in csr_iter: |
| expected = np.zeros((batch_size, num_cols)) |
| end = begin + batch_size |
| expected[:num_rows - begin] = dns[begin:end] |
| if end > num_rows: |
| expected[num_rows - begin:] = dns[0:end - num_rows] |
| assert_almost_equal(batch.data[0].asnumpy(), expected) |
| begin += batch_size |
| |
| |
| def test_LibSVMIter(tmpdir): |
| |
| def check_libSVMIter_synthetic(): |
| data_path = os.path.join(str(tmpdir), 'data.t') |
| label_path = os.path.join(str(tmpdir), 'label.t') |
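        # LibSVM rows have the form '<label> <index>:<value> ...'; the
        # indices written here are 0-based, matching the dense rows below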
| with open(data_path, 'w') as fout: |
| fout.write('1.0 0:0.5 2:1.2\n') |
| fout.write('-2.0\n') |
| fout.write('-3.0 0:0.6 1:2.4 2:1.2\n') |
| fout.write('4 2:-1.2\n') |
| |
| with open(label_path, 'w') as fout: |
| fout.write('1.0\n') |
| fout.write('-2.0 0:0.125\n') |
| fout.write('-3.0 2:1.2\n') |
| fout.write('4 1:1.0 2:-1.2\n') |
| |
| data_train = mx.io.LibSVMIter(data_libsvm=data_path, label_libsvm=label_path, |
| data_shape=(3, ), label_shape=(3, ), batch_size=3) |
| |
| first = mx.nd.array([[0.5, 0., 1.2], [0., 0., 0.], [0.6, 2.4, 1.2]]) |
| second = mx.nd.array([[0., 0., -1.2], [0.5, 0., 1.2], [0., 0., 0.]]) |
| i = 0 |
| for _ in iter(data_train): |
| expected = first.asnumpy() if i == 0 else second.asnumpy() |
| data = data_train.getdata() |
| data.check_format(True) |
| assert_almost_equal(data.asnumpy(), expected) |
| i += 1 |
| |
| def check_libSVMIter_news_data(): |
| news_metadata = { |
| 'name': 'news20.t', |
| 'origin_name': 'news20.t.bz2', |
| 'url': "https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/news20.t.bz2", |
| 'feature_dim': 62060 + 1, |
| 'num_classes': 20, |
| 'num_examples': 3993, |
| } |
| batch_size = 33 |
| num_examples = news_metadata['num_examples'] |
| data_dir = os.path.join(str(tmpdir), 'data') |
| get_bz2_data(data_dir, news_metadata['name'], news_metadata['url'], |
| news_metadata['origin_name']) |
| path = os.path.join(data_dir, news_metadata['name']) |
| data_train = mx.io.LibSVMIter(data_libsvm=path, data_shape=(news_metadata['feature_dim'],), |
| batch_size=batch_size) |
| for _ in range(2): |
| num_batches = 0 |
| for batch in data_train: |
| # check the range of labels |
| data = batch.data[0] |
| label = batch.label[0] |
| data.check_format(True) |
| assert(np.sum(label.asnumpy() > 20) == 0) |
| assert(np.sum(label.asnumpy() <= 0) == 0) |
| num_batches += 1 |
            expected_num_batches = num_examples // batch_size
            assert num_batches == expected_num_batches, num_batches
| data_train.reset() |
| |
| def check_libSVMIter_exception(): |
| data_path = os.path.join(str(tmpdir), 'data.t') |
| label_path = os.path.join(str(tmpdir), 'label.t') |
| with open(data_path, 'w') as fout: |
| fout.write('1.0 0:0.5 2:1.2\n') |
| fout.write('-2.0\n') |
            # the line below has a negative index and should make the iterator throw
| fout.write('-3.0 -1:0.6 1:2.4 2:1.2\n') |
| fout.write('4 2:-1.2\n') |
| |
| with open(label_path, 'w') as fout: |
| fout.write('1.0\n') |
| fout.write('-2.0 0:0.125\n') |
| fout.write('-3.0 2:1.2\n') |
| fout.write('4 1:1.0 2:-1.2\n') |
| data_train = mx.io.LibSVMIter(data_libsvm=data_path, label_libsvm=label_path, |
| data_shape=(3, ), label_shape=(3, ), batch_size=3) |
| for _ in iter(data_train): |
            data_train.getdata().asnumpy()
| |
| check_libSVMIter_synthetic() |
| check_libSVMIter_news_data() |
| assertRaises(MXNetError, check_libSVMIter_exception) |
| |
| |
| def test_DataBatch(): |
| from mxnet.io import DataBatch |
| import re |
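    # the optional 'L' in the patterns tolerates Python 2-style long reprs
    # in the shape tuples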
| batch = DataBatch(data=[mx.nd.ones((2, 3))]) |
| assert re.match( |
| r'DataBatch: data shapes: \[\(2L?, 3L?\)\] label shapes: None', str(batch)) |
| batch = DataBatch(data=[mx.nd.ones((2, 3)), mx.nd.ones( |
| (7, 8))], label=[mx.nd.ones((4, 5))]) |
| assert re.match( |
| r'DataBatch: data shapes: \[\(2L?, 3L?\), \(7L?, 8L?\)\] label shapes: \[\(4L?, 5L?\)\]', str(batch)) |
| |
| |
| @pytest.mark.skip(reason="https://github.com/apache/incubator-mxnet/issues/18382") |
| def test_CSVIter(tmpdir): |
| def check_CSVIter_synthetic(dtype='float32'): |
| data_path = os.path.join(str(tmpdir), 'data.t') |
| label_path = os.path.join(str(tmpdir), 'label.t') |
| entry_str = '1' |
        if dtype == 'int32':
            entry_str = '200000001'
        elif dtype == 'int64':
            entry_str = '2147483648'
| with open(data_path, 'w') as fout: |
| for _ in range(1000): |
| fout.write(','.join([entry_str for _ in range(8*8)]) + '\n') |
| with open(label_path, 'w') as fout: |
| for _ in range(1000): |
| fout.write('0\n') |
| |
| data_train = mx.io.CSVIter(data_csv=data_path, data_shape=(8, 8), |
| label_csv=label_path, batch_size=100, dtype=dtype) |
| expected = mx.nd.ones((100, 8, 8), dtype=dtype) * int(entry_str) |
| for _ in iter(data_train): |
| data_batch = data_train.getdata() |
| assert_almost_equal(data_batch.asnumpy(), expected.asnumpy()) |
| assert data_batch.asnumpy().dtype == expected.asnumpy().dtype |
| |
| for dtype in ['int32', 'int64', 'float32']: |
| check_CSVIter_synthetic(dtype=dtype) |
| |
| def test_ImageRecordIter_seed_augmentation(cifar10): |
| seed_aug = 3 |
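    # seed_aug fixes the RNG used by the random augmenters, so two iterators
    # created with the same seed should apply identical transformations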
| |
| def assert_dataiter_items_equals(dataiter1, dataiter2): |
| """ |
| Asserts that two data iterators have the same numbner of batches, |
| that the batches have the same number of items, and that the items |
| are the equal. |
| """ |
| for batch1, batch2 in zip_longest(dataiter1, dataiter2): |
| |
            # ensure the iterators contain the same number of batches;
            # zip_longest returns None once one of them runs out of batches
| assert batch1 and batch2, 'The iterators do not contain the same number of batches' |
| |
| # ensure batches are of same length |
| assert len(batch1.data) == len(batch2.data), 'The returned batches are not of the same length' |
| |
| # ensure batch data is the same |
| for i in range(0, len(batch1.data)): |
| data1 = batch1.data[i].asnumpy().astype(np.uint8) |
| data2 = batch2.data[i].asnumpy().astype(np.uint8) |
| assert(np.array_equal(data1, data2)) |
| |
| def assert_dataiter_items_not_equals(dataiter1, dataiter2): |
| """ |
| Asserts that two data iterators have the same numbner of batches, |
| that the batches have the same number of items, and that the items |
| are the _not_ equal. |
| """ |
| for batch1, batch2 in zip_longest(dataiter1, dataiter2): |
| |
            # ensure the iterators contain the same number of batches;
            # zip_longest returns None once one of them runs out of batches
| assert batch1 and batch2, 'The iterators do not contain the same number of batches' |
| |
| # ensure batches are of same length |
| assert len(batch1.data) == len(batch2.data), 'The returned batches are not of the same length' |
| |
            # return as soon as any pair of items differs
| for i in range(0, len(batch1.data)): |
| data1 = batch1.data[i].asnumpy().astype(np.uint8) |
| data2 = batch2.data[i].asnumpy().astype(np.uint8) |
| if not np.array_equal(data1, data2): |
| return |
| assert False, 'Expected data iterators to be different, but they are the same' |
| |
    # with the same seed_aug, both iterators should produce identical images
| dataiter1 = mx.io.ImageRecordIter( |
| path_imgrec=os.path.join(cifar10, 'cifar', 'train.rec'), |
| mean_img=os.path.join(cifar10, 'cifar', 'cifar10_mean.bin'), |
| shuffle=False, |
| data_shape=(3, 28, 28), |
| batch_size=3, |
| rand_crop=True, |
| rand_mirror=True, |
| max_random_scale=1.3, |
| max_random_illumination=3, |
| max_rotate_angle=10, |
| random_l=50, |
| random_s=40, |
| random_h=10, |
| max_shear_ratio=2, |
| seed_aug=seed_aug) |
| |
| dataiter2 = mx.io.ImageRecordIter( |
| path_imgrec=os.path.join(cifar10, 'cifar', 'train.rec'), |
| mean_img=os.path.join(cifar10, 'cifar', 'cifar10_mean.bin'), |
| shuffle=False, |
| data_shape=(3, 28, 28), |
| batch_size=3, |
| rand_crop=True, |
| rand_mirror=True, |
| max_random_scale=1.3, |
| max_random_illumination=3, |
| max_rotate_angle=10, |
| random_l=50, |
| random_s=40, |
| random_h=10, |
| max_shear_ratio=2, |
| seed_aug=seed_aug) |
| |
| assert_dataiter_items_equals(dataiter1, dataiter2) |
| |
    # with a different seed_aug, the augmented images should differ
| dataiter1.reset() |
| dataiter2 = mx.io.ImageRecordIter( |
| path_imgrec=os.path.join(cifar10, 'cifar', 'train.rec'), |
| mean_img=os.path.join(cifar10, 'cifar', 'cifar10_mean.bin'), |
| shuffle=False, |
| data_shape=(3, 28, 28), |
| batch_size=3, |
| rand_crop=True, |
| rand_mirror=True, |
| max_random_scale=1.3, |
| max_random_illumination=3, |
| max_rotate_angle=10, |
| random_l=50, |
| random_s=40, |
| random_h=10, |
| max_shear_ratio=2, |
| seed_aug=seed_aug+1) |
| |
| assert_dataiter_items_not_equals(dataiter1, dataiter2) |
| |
    # without random augmentations, seed_aug should not perturb the output
| dataiter1 = mx.io.ImageRecordIter( |
| path_imgrec=os.path.join(cifar10, 'cifar', 'train.rec'), |
| mean_img=os.path.join(cifar10, 'cifar', 'cifar10_mean.bin'), |
| shuffle=False, |
| data_shape=(3, 28, 28), |
| batch_size=3, |
| seed_aug=seed_aug) |
| |
| dataiter2 = mx.io.ImageRecordIter( |
| path_imgrec=os.path.join(cifar10, 'cifar', 'train.rec'), |
| mean_img=os.path.join(cifar10, 'cifar', 'cifar10_mean.bin'), |
| shuffle=False, |
| data_shape=(3, 28, 28), |
| batch_size=3, |
| seed_aug=seed_aug) |
| |
| assert_dataiter_items_equals(dataiter1, dataiter2) |