This module contains classes for loading and prefetching batches of data.
Example usage:

    import numpy as np
    import image_tool
    from PIL import Image

    # ImageBatchIter is the class defined in this module and is assumed
    # to be importable here.

    tool = image_tool.ImageTool()

    def image_transform(img_path):
        # resize within the (112, 128) range, randomly crop to 96x96, then flip
        return tool.load(img_path).resize_by_range(
            (112, 128)).random_crop(
            (96, 96)).flip().get()

    data = ImageBatchIter('train.txt', 3, image_transform, shuffle=True,
                          delimiter=',', image_folder='images/',
                          capacity=10)
    data.start()
    try:
        # imgs is a numpy array for a batch of images,
        # shape: (batch_size, 3 (RGB), height, width)
        imgs, labels = data.next()

        # convert the numpy array back into images (CHW -> HWC)
        for idx in range(imgs.shape[0]):
            img = Image.fromarray(
                imgs[idx].astype(np.uint8).transpose(1, 2, 0), 'RGB')
            img.save('img%d.png' % idx)
    finally:
        # stop the background prefetching
        data.end()

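The prefetching pattern behind `start()`, `next()` and `end()` in the example above can be sketched as a producer/consumer pair: a background worker builds batches and pushes them into a bounded queue of size `capacity`, while `next()` pops finished batches. This is a minimal sketch under stated assumptions, not this module's `ImageBatchIter`: the class name `PrefetchBatchIter`, the in-memory `samples` list of `(item, label)` pairs, and the use of a thread (rather than, say, a separate process) are all hypothetical choices made for illustration.

```python
import queue
import random
import threading


class PrefetchBatchIter:
    """Hypothetical sketch: a background thread fills a bounded queue with batches."""

    def __init__(self, samples, batch_size, transform, shuffle=False, capacity=10):
        self.samples = list(samples)      # assumed: list of (item, label) pairs
        self.batch_size = batch_size
        self.transform = transform        # applied to each item, e.g. image loading
        self.shuffle = shuffle
        self.queue = queue.Queue(maxsize=capacity)  # at most `capacity` batches buffered
        self._stop = threading.Event()
        self._worker = None

    def start(self):
        """Launch the producer thread that prefetches batches."""
        self._stop.clear()
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def _run(self):
        idx = 0
        order = list(range(len(self.samples)))
        while not self._stop.is_set():
            batch_x, batch_y = [], []
            while len(batch_x) < self.batch_size:
                if idx == 0 and self.shuffle:
                    random.shuffle(order)  # reshuffle at the start of each epoch
                item, label = self.samples[order[idx]]
                batch_x.append(self.transform(item))
                batch_y.append(label)
                idx = (idx + 1) % len(order)  # wrap around and keep cycling
            # Block while the queue is full, waking periodically to notice end().
            while not self._stop.is_set():
                try:
                    self.queue.put((batch_x, batch_y), timeout=0.1)
                    break
                except queue.Full:
                    pass

    def next(self):
        """Return the next prefetched (inputs, labels) batch, blocking if none is ready."""
        return self.queue.get()

    def end(self):
        """Signal the producer to stop and wait for it to exit."""
        self._stop.set()
        if self._worker is not None:
            self._worker.join()


samples = [(i, i * 10) for i in range(5)]
it = PrefetchBatchIter(samples, batch_size=2, transform=lambda x: x * x, capacity=2)
it.start()
first = it.next()   # ([0, 1], [0, 10])
second = it.next()  # ([4, 9], [20, 30])
third = it.next()   # ([16, 0], [40, 0]), wrapping around the 5 samples
it.end()
```

With `capacity=2` the producer runs at most two batches ahead of the consumer, which bounds memory use while still overlapping data preparation with consumption. A thread keeps the sketch short; for CPU-heavy work such as image decoding, a separate process is a common alternative because it is not limited by Python's GIL.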
Iterates over the dataset to fetch batches of data.
Args: