blob: 753edbf66b25e84e704095cb8c2df56d1c461a8f [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
'''
This module includes classes for loading and prefetching data batches.
Example usage::
import image_tool
from PIL import Image
tool = image_tool.ImageTool()
def image_transform(img_path):
global tool
return tool.load(img_path).resize_by_range(
(112, 128)).random_crop(
(96, 96)).flip().get()
data = ImageBatchIter('train.txt', 3,
image_transform, shuffle=True, delimiter=',',
image_folder='images/',
capacity=10)
data.start()
# imgs is a numpy array for a batch of images,
# shape: batch_size, 3 (RGB), height, width
imgs, labels = data.next()
# convert numpy array back into images
for idx in range(imgs.shape[0]):
img = Image.fromarray(imgs[idx].astype(np.uint8).transpose(1, 2, 0),
'RGB')
img.save('img%d.png' % idx)
data.end()
'''
from __future__ import print_function
from __future__ import absolute_import
from builtins import range
from builtins import object
import os
import random
import time
from multiprocessing import Process, Queue
import numpy as np
class ImageBatchIter(object):
'''Utility for iterating over an image dataset to get mini-batches.
Args:
img_list_file(str): name of the file containing image meta data; each
line consists of image_path_suffix delimiter meta_info,
where meta info could be label index or label strings, etc.
meta_info should not contain the delimiter. If the meta_info
of each image is just the label index, then we will parse the
label index into a numpy array with length=batchsize
(for compatibility); otherwise, we return a list of meta_info;
if meta info is available, we return a list of None.
batch_size(int): num of samples in one mini-batch
image_transform: a function for image augmentation; it accepts the full
image path and outputs a list of augmented images.
shuffle(boolean): True for shuffling images in the list
delimiter(char): delimiter between image_path_suffix and label, e.g.,
space or comma
image_folder(boolean): prefix of the image path
capacity(int): the max num of mini-batches in the internal queue.
'''
def __init__(self, img_list_file, batch_size, image_transform,
shuffle=True, delimiter=' ', image_folder=None, capacity=10):
self.img_list_file = img_list_file
self.queue = Queue(capacity)
self.batch_size = batch_size
self.image_transform = image_transform
self.shuffle = shuffle
self.delimiter = delimiter
self.image_folder = image_folder
self.stop = False
self.p = None
with open(img_list_file, 'r') as fd:
self.num_samples = len(fd.readlines())
def start(self):
self.p = Process(target=self.run)
self.p.start()
return
def __next__(self):
assert self.p is not None, 'call start before next'
while self.queue.empty():
time.sleep(0.1)
x, y = self.queue.get() # dequeue one mini-batch
return x, y
def stop(self):
self.end();
def end(self):
if self.p is not None:
self.stop = True
time.sleep(0.1)
self.p.terminate()
def run(self):
img_list = []
is_label_index = True
for line in open(self.img_list_file, 'r'):
item = line.strip('\n').split(self.delimiter)
if len(item) < 2:
is_label_index = False
img_list.append((item[0].strip(), None))
else:
if not item[1].strip().isdigit():
# the meta info is not label index
is_label_index = False
img_list.append((item[0].strip(), item[1].strip()))
index = 0 # index for the image
if self.shuffle:
random.shuffle(img_list)
while not self.stop:
if not self.queue.full():
x, y = [], []
i = 0
while i < self.batch_size:
img_path, img_meta = img_list[index]
aug_images = self.image_transform(
os.path.join(self.image_folder, img_path))
assert i + len(aug_images) <= self.batch_size, \
'too many images (%d) in a batch (%d)' % \
(i + len(aug_images), self.batch_size)
for img in aug_images:
ary = np.asarray(img.convert('RGB'), dtype=np.float32)
x.append(ary.transpose(2, 0, 1))
if is_label_index:
y.append(int(img_meta))
else:
y.append(img_meta)
i += 1
index += 1
if index == self.num_samples:
index = 0 # reset to the first image
if self.shuffle:
random.shuffle(img_list)
# enqueue one mini-batch
if is_label_index:
self.queue.put((np.asarray(x), np.asarray(y, dtype=np.int32)))
else:
self.queue.put((np.asarray(x), y))
else:
time.sleep(0.1)
return
if __name__ == '__main__':
from . import image_tool
from PIL import Image
tool = image_tool.ImageTool()
def image_transform(img_path):
global tool
return tool.load(img_path).resize_by_range(
(112, 128)).random_crop(
(96, 96)).flip().get()
data = ImageBatchIter('train.txt', 3,
image_transform, shuffle=False, delimiter=',',
image_folder='images/',
capacity=10)
data.start()
imgs, labels = next(data)
print(labels)
for idx in range(imgs.shape[0]):
img = Image.fromarray(imgs[idx].astype(np.uint8).transpose(1, 2, 0),
'RGB')
img.save('img%d.png' % idx)
data.end()