blob: 0ca8e1fd6653206c8c46306932caf2553bee609c [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
""" data iterator for mnist """
import sys
import os
# code to automatically download dataset
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
import get_data
import mxnet as mx
def mnist_iterator(batch_size, input_shape):
"""return train and val iterators for mnist"""
# download data
get_data.GetMNIST_ubyte()
flat = False if len(input_shape) == 3 else True
train_dataiter = mx.io.MNISTIter(
image="data/train-images-idx3-ubyte",
label="data/train-labels-idx1-ubyte",
input_shape=input_shape,
batch_size=batch_size,
shuffle=True,
flat=flat)
val_dataiter = mx.io.MNISTIter(
image="data/t10k-images-idx3-ubyte",
label="data/t10k-labels-idx1-ubyte",
input_shape=input_shape,
batch_size=batch_size,
flat=flat)
return (train_dataiter, val_dataiter)