| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| |
| try: |
| import pickle |
| except ImportError: |
| import cPickle as pickle |
| |
| import numpy as np |
| import os |
| import sys |
| |
| |
def load_dataset(filepath):
    """Read a single pickled CIFAR-10 batch file.

    The pickle must hold a dict with a ``'data'`` array (one row of
    3 * 32 * 32 values per image) and a ``'labels'`` sequence.

    Note: unpickling can execute arbitrary code, so only point this at
    files you trust.

    Args:
        filepath: path to the batch file.

    Returns:
        (images, labels): uint8 images reshaped to (N, 3, 32, 32) and
        uint8 labels shaped (N, 1).
    """
    with open(filepath, 'rb') as handle:
        try:
            # Python 3: decode byte strings stored by Python 2 as latin1.
            batch = pickle.load(handle, encoding='latin1')
        except TypeError:
            # Python 2's pickle.load() does not accept ``encoding``.
            batch = pickle.load(handle)
    images = batch['data'].astype(dtype=np.uint8).reshape((-1, 3, 32, 32))
    labels = np.asarray(batch['labels'], dtype=np.uint8)
    return images, labels.reshape(labels.size, 1)
| |
| |
def load_train_data(dir_path='/tmp/cifar-10-batches-py', num_batches=5):
    """Load and stack the CIFAR-10 training batches ``data_batch_1..N``.

    The extracted dataset must already exist under ``dir_path``; a missing
    batch file is handled by ``check_dataset_exist``.

    Args:
        dir_path: directory containing the ``data_batch_<i>`` files.
        num_batches: how many consecutive batch files (starting at 1) to load.

    Returns:
        (images, labels): float32 images shaped (total, 3, 32, 32) holding the
        raw 0-255 pixel values, and int32 labels shaped (total, 1). When
        ``num_batches <= 0`` the arrays are empty with shapes (0, 3, 32, 32)
        and (0,).
    """
    image_batches = []
    label_batches = []
    for did in range(1, num_batches + 1):
        fname_train_data = dir_path + "/data_batch_{}".format(did)
        image, label = load_dataset(check_dataset_exist(fname_train_data))
        image_batches.append(image)
        label_batches.append(label)

    if not image_batches:
        # Keep the empty-result shapes the pre-allocating version produced.
        return (np.empty((0, 3, 32, 32), dtype=np.float32),
                np.empty((0,), dtype=np.int32))

    # Concatenate instead of pre-allocating num_batches * 10000 rows, so batch
    # files that do not hold exactly 10000 images are stacked correctly rather
    # than failing with a broadcast error.
    images = np.concatenate(image_batches, axis=0).astype(np.float32)
    labels = np.concatenate(label_batches, axis=0).astype(np.int32)
    return images, labels
| |
| |
def load_test_data(dir_path='/tmp/cifar-10-batches-py'):
    """Load the CIFAR-10 ``test_batch`` file found in ``dir_path``.

    The extracted dataset must already exist under ``dir_path``; a missing
    file is handled by ``check_dataset_exist``.

    Returns:
        (images, labels): float32 copies of the raw 0-255 images, shaped
        (N, 3, 32, 32), and int32 labels shaped (N, 1).
    """
    test_file = check_dataset_exist(dir_path + "/test_batch")
    raw_images, raw_labels = load_dataset(test_file)
    return raw_images.astype(np.float32), raw_labels.astype(np.int32)
| |
| |
def check_dataset_exist(dirpath):
    """Return ``dirpath`` unchanged if it exists, otherwise abort the program.

    ``dirpath`` may name a file or a directory (the loaders in this module
    pass individual batch-file paths).

    Raises:
        SystemExit: with status 1 when ``dirpath`` does not exist, after
        printing download instructions to stderr.
    """
    if os.path.exists(dirpath):
        return dirpath
    # A missing dataset is a failure: report it on stderr and exit non-zero
    # so shell scripts / CI do not mistake the aborted run for a success.
    # (sys.stderr.write keeps this Python 2 compatible, like the pickle import.)
    sys.stderr.write(
        'Please download the cifar10 dataset using python data/download_cifar10.py\n'
    )
    sys.exit(1)
| |
| |
def normalize(train_x, val_x,
              mean=(0.4914, 0.4822, 0.4465),
              std=(0.2023, 0.1994, 0.2010)):
    """Scale pixels to [0, 1] and standardize every channel, in place.

    Both arrays are modified in place and also returned. They must be
    floating-point (in-place true division) and channel-first, i.e. indexed
    as ``x[:, channel, :, :]``.

    Args:
        train_x: training images, shaped (N, C, H, W).
        val_x: validation images, shaped (M, C, H, W).
        mean: per-channel means applied after scaling by 1/255.
        std: per-channel standard deviations applied after subtracting mean.

    Returns:
        (train_x, val_x): the same array objects, normalized.
    """
    for x in (train_x, val_x):
        x /= 255
        # Normalize every channel that has a mean/std entry. The previous
        # ``range(0, 2)`` loop silently skipped channel 2.
        for ch, (m, s) in enumerate(zip(mean, std)):
            x[:, ch, :, :] -= m
            x[:, ch, :, :] /= s
    return train_x, val_x
| |
def load():
    """Load CIFAR-10 from the default location and normalize the images.

    Returns:
        (train_x, train_y, val_x, val_y): normalized float32 image arrays for
        the training and test splits, plus their int32 labels flattened to
        1-D vectors.
    """
    train_images, train_labels = load_train_data()
    test_images, test_labels = load_test_data()
    train_images, test_images = normalize(train_images, test_images)
    # Labels arrive shaped (N, 1); flatten them to plain (N,) vectors.
    return (train_images, train_labels.flatten(),
            test_images, test_labels.flatten())