数据(Data)

这个模块包含加载和预获取批数据的类。

示例用法:

import image_tool
from PIL import Image

tool = image_tool.ImageTool()

def image_transform(img_path):
    global tool
    return tool.load(img_path).resize_by_range(
        (112, 128)).random_crop(
        (96, 96)).flip().get()

data = ImageBatchIter('train.txt', 3,
                      image_transform, shuffle=True, delimiter=',',
                      image_folder='images/',
                      capacity=10)
data.start()
# imgs is a numpy array for a batch of images,
# shape: batch_size, 3 (RGB), height, width
imgs, labels = data.next()

# convert numpy array back into images
for idx in range(imgs.shape[0]):
    img = Image.fromarray(imgs[idx].astype(np.uint8).transpose(1, 2, 0),
                          'RGB')
    img.save('img%d.png' % idx)
data.end()

class singa.data.ImageBatchIter(img_list_file, batch_size, image_transform, shuffle=True, delimiter=' ', image_folder=None, capacity=10)

迭代地从数据集中获取批数据。

参数:

  • img_list_file (str) – 包含源数据的文件名;每行包含image_path_suffix和标签
  • batch_size (int) – 每个mini-bach包含的样本数目
  • image_transform – 图像增强函数;它接受完整的图像路径并输出一系列增强后的图像
  • shuffle (boolean) – 为真表示对列表做搅乱
  • delimiter (char) – image_path_suffix和标签之间的分割符, 例如空格或逗号
  • image_folder (boolean) – 图片路径的前缀
  • capacity (int) – 内部队列的最大mini-batch数目