# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CIFAR-10 data set.
See http://www.cs.toronto.edu/~kriz/cifar.html.
"""
import os
import tensorflow as tf
HEIGHT = 32
WIDTH = 32
DEPTH = 3
class Cifar10DataSet(object):
"""Cifar10 data set.
Described by http://www.cs.toronto.edu/~kriz/cifar.html.
"""
def __init__(self, data_dir, subset="train", use_distortion=True):
self.data_dir = data_dir
self.subset = subset
self.use_distortion = use_distortion

  def get_filenames(self):
    if self.subset in ["train", "validation", "eval"]:
      return [os.path.join(self.data_dir, self.subset + ".tfrecords")]
    else:
      raise ValueError('Invalid data subset "%s"' % self.subset)

  def parser(self, serialized_example):
    """Parses a single tf.Example into image and label tensors."""
    # Dimensions of the images in the CIFAR-10 dataset.
    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
    # input format.
    features = tf.parse_single_example(
        serialized_example,
        features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64),
        })
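    # The image is stored as a flat string of raw uint8 bytes; decode it back
    # into a vector of length DEPTH * HEIGHT * WIDTH before reshaping.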
    image = tf.decode_raw(features["image"], tf.uint8)
    image.set_shape([DEPTH * HEIGHT * WIDTH])

    # Reshape from [depth * height * width] to [depth, height, width], then
    # transpose to [height, width, depth].
    image = tf.cast(
        tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0]),
        tf.float32)
    label = tf.cast(features["label"], tf.int32)

    # Custom preprocessing.
    image = self.preprocess(image)

    return image, label

  def make_batch(self, batch_size):
    """Read the images and labels from 'filenames'."""
    filenames = self.get_filenames()
    # Repeat infinitely.
    dataset = tf.data.TFRecordDataset(filenames).repeat()

    # Parse records.
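    # num_parallel_calls bounds how many records are parsed concurrently;
    # here it is tied to the batch size.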
    dataset = dataset.map(self.parser, num_parallel_calls=batch_size)

    # Potentially shuffle records.
    if self.subset == "train":
      min_queue_examples = int(
          Cifar10DataSet.num_examples_per_epoch(self.subset) * 0.4)
      # Ensure that the capacity is sufficiently large to provide good random
      # shuffling.
      dataset = dataset.shuffle(
          buffer_size=min_queue_examples + 3 * batch_size)

    # Batch it up.
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    image_batch, label_batch = iterator.get_next()

    return image_batch, label_batch

  def preprocess(self, image):
    """Preprocess a single image in [height, width, depth] layout."""
    if self.subset == "train" and self.use_distortion:
      # Pad 4 pixels on each side, randomly crop back to HEIGHT x WIDTH, and
      # randomly flip left/right (training-time augmentation).
      image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
      image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
      image = tf.image.random_flip_left_right(image)
    return image

  @staticmethod
  def num_examples_per_epoch(subset="train"):
    if subset == "train":
      return 45000
    elif subset == "validation":
      return 5000
    elif subset == "eval":
      return 10000
    else:
      raise ValueError('Invalid data subset "%s"' % subset)
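

# Example usage: a minimal sketch, assuming TensorFlow 1.x graph mode and that
# train/validation/eval .tfrecords files already exist under DATA_DIR. The
# path and batch size below are illustrative, not part of this module.
if __name__ == "__main__":
  DATA_DIR = "/tmp/cifar-10-data"
  dataset = Cifar10DataSet(DATA_DIR, subset="train", use_distortion=True)
  images, labels = dataset.make_batch(batch_size=128)
  with tf.Session() as sess:
    image_values, label_values = sess.run([images, labels])
    # Expect images of shape (128, 32, 32, 3) and labels of shape (128,).
    print("images:", image_values.shape, "labels:", label_values.shape)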