| # -*- coding: utf-8 -*- | |
| from __future__ import print_function | |
| import os | |
| import sys | |
| curr_path = os.path.abspath(os.path.dirname(__file__)) | |
| sys.path.append(os.path.join(curr_path, "../python")) | |
| import mxnet as mx | |
| import random | |
| import argparse | |
| import cv2 | |
| import time | |
| import traceback | |
| try: | |
| import multiprocessing | |
| except ImportError: | |
| multiprocessing = None | |
def list_image(root, recursive, exts):
    """Yield (index, relative_path, label) for each image found under *root*.

    When *recursive* is true, walks sub-directories (following symlinks) in
    sorted order, assigns one label per directory in first-seen order, and
    prints the directory->label mapping at the end.  Otherwise only files
    directly inside *root* are listed, all with label 0.  Only files whose
    lower-cased extension is in *exts* are included.
    """
    idx = 0
    if not recursive:
        for entry in sorted(os.listdir(root)):
            full = os.path.join(root, entry)
            ext = os.path.splitext(entry)[1].lower()
            if os.path.isfile(full) and ext in exts:
                yield (idx, os.path.relpath(full, root), 0)
                idx += 1
        return
    labels = {}
    for dirpath, subdirs, filenames in os.walk(root, followlinks=True):
        # sorting in place fixes the traversal order of os.walk
        subdirs.sort()
        filenames.sort()
        for entry in filenames:
            full = os.path.join(dirpath, entry)
            ext = os.path.splitext(entry)[1].lower()
            if not (os.path.isfile(full) and ext in exts):
                continue
            if dirpath not in labels:
                labels[dirpath] = len(labels)
            yield (idx, os.path.relpath(full, root), labels[dirpath])
            idx += 1
    # report the label assigned to each directory, ordered by label value
    for dirpath, label in sorted(labels.items(), key=lambda kv: kv[1]):
        print(os.path.relpath(dirpath, root), label)
def write_list(path_out, image_list):
    """Write *image_list* entries to *path_out* in .lst format.

    Each output line is: index<TAB>label<TAB>...<TAB>relative-path, where the
    index is an integer and every label is rendered as a float.
    """
    with open(path_out, 'w') as fout:
        for record in image_list:
            fields = ['%d' % record[0]]
            fields.extend('%f' % label for label in record[2:])
            fields.append('%s\n' % record[1])
            fout.write('\t'.join(fields))
def make_list(args):
    """Create .lst file(s) at args.prefix from images found under args.root.

    Shuffles deterministically (seed 100) when args.shuffle is True, splits
    the listing into args.chunks chunks, and within each chunk optionally
    carves out test/val/train portions according to args.test_ratio and
    args.train_ratio.  Chunk files are suffixed '_<i>' when chunks > 1.
    """
    image_list = list(list_image(args.root, args.recursive, args.exts))
    if args.shuffle is True:
        random.seed(100)
        random.shuffle(image_list)
    N = len(image_list)
    # floor division: plain '/' yields a float on Python 3, which would
    # break the slice indices computed below
    chunk_size = (N + args.chunks - 1) // args.chunks
    for i in range(args.chunks):  # range (not xrange) works on Py2 and Py3
        chunk = image_list[i * chunk_size:(i + 1) * chunk_size]
        str_chunk = '_%d' % i if args.chunks > 1 else ''
        sep = int(chunk_size * args.train_ratio)
        sep_test = int(chunk_size * args.test_ratio)
        if args.train_ratio == 1.0:
            write_list(args.prefix + str_chunk + '.lst', chunk)
        else:
            if args.test_ratio:
                write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])
            if args.train_ratio + args.test_ratio < 1.0:
                write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:])
            write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep])
def read_list(path_in):
    """Yield parsed records from the .lst file at *path_in*.

    Each yielded item is [int_index, path_string, float_label, ...].  Lines
    with fewer than three tab-separated fields, or with unparsable index or
    label values, are reported and skipped rather than raising.
    """
    with open(path_in) as fin:
        for raw in fin:
            line = [i.strip() for i in raw.strip().split('\t')]
            line_len = len(line)
            # a valid record needs at least index, one label, and a path
            if line_len < 3:
                print('lst should at least has three parts, but only has %s parts for %s' % (line_len, line))
                continue
            try:
                item = [int(line[0])] + [line[-1]] + [float(i) for i in line[1:-1]]
            except Exception as e:  # 'as' form is required on Python 3
                print('Parsing lst met error for %s, detail: %s' % (line, e))
                continue
            yield item
def image_encode(args, i, item, q_out):
    """Read, transform, and pack one image into a recordio payload.

    Loads the image described by *item* ([index, relpath, label(s)...]),
    optionally center-crops and resizes it, packs it with an IRHeader, and
    puts (i, packed_bytes, item) on *q_out*.  On any failure the payload is
    None so the consumer can skip the record.
    """
    fullpath = os.path.join(args.root, item[1])

    if len(item) > 3 and args.pack_label:
        # multi-dimensional label: store the whole label vector in the header
        header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
    else:
        header = mx.recordio.IRHeader(0, item[2], item[0], 0)

    if args.pass_through:
        try:
            # binary mode: the file holds encoded image bytes, not text
            with open(fullpath, 'rb') as fin:
                img = fin.read()
            s = mx.recordio.pack(header, img)
            q_out.put((i, s, item))
        except Exception as e:  # 'as' form is required on Python 3
            traceback.print_exc()
            print('pack_img error:', item[1], e)
            q_out.put((i, None, item))
        return

    try:
        img = cv2.imread(fullpath, args.color)
    except Exception:  # narrowed from bare except; still reports and skips
        traceback.print_exc()
        print('imread error trying to load file: %s ' % fullpath)
        q_out.put((i, None, item))
        return
    if img is None:
        print('imread read blank (None) image for file: %s' % fullpath)
        q_out.put((i, None, item))
        return
    if args.center_crop:
        # crop the longer edge so the image becomes square;
        # '//' keeps the slice bounds integral on Python 3
        if img.shape[0] > img.shape[1]:
            margin = (img.shape[0] - img.shape[1]) // 2
            img = img[margin:margin + img.shape[1], :]
        else:
            margin = (img.shape[1] - img.shape[0]) // 2
            img = img[:, margin:margin + img.shape[0]]
    if args.resize:
        # scale so the shorter edge equals args.resize; cv2.resize needs ints
        if img.shape[0] > img.shape[1]:
            newsize = (args.resize, img.shape[0] * args.resize // img.shape[1])
        else:
            newsize = (img.shape[1] * args.resize // img.shape[0], args.resize)
        img = cv2.resize(img, newsize)

    try:
        s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)
        q_out.put((i, s, item))
    except Exception as e:
        traceback.print_exc()
        print('pack_img error on file: %s' % fullpath, e)
        q_out.put((i, None, item))
        return
def read_worker(args, q_in, q_out):
    """Consume (index, item) tasks from *q_in* until a None sentinel arrives,
    encoding each image and forwarding the result to *q_out*."""
    while True:
        task = q_in.get()
        if task is None:
            break
        index, item = task
        image_encode(args, index, item, q_out)
def write_worker(q_out, fname, working_dir):
    """Drain encoded images from *q_out* (terminated by a None sentinel) and
    write them — restored to original list order — to
    <working_dir>/<fname-stem>.rec with a matching .idx index file.
    Progress is printed every 1000 records.
    """
    start = time.time()
    written = 0
    stem = os.path.splitext(os.path.basename(fname))[0]
    record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, stem + '.idx'),
                                           os.path.join(working_dir, stem + '.rec'), 'w')
    # results may arrive out of order from multiple readers; hold them here
    # until the next expected sequence number shows up
    pending = {}
    done = False
    while not done:
        msg = q_out.get()
        if msg is None:
            done = True
        else:
            seq, payload, item = msg
            pending[seq] = (payload, item)
        while written in pending:
            payload, item = pending.pop(written)
            if payload is not None:
                record.write_idx(item[0], payload)
            if written % 1000 == 0:
                now = time.time()
                print('time:', now - start, ' count:', written)
                start = now
            written += 1
def _str2bool(value):
    """argparse converter for boolean options.

    Plain ``type=bool`` treats ANY non-empty string (including 'False') as
    True, so e.g. ``--shuffle False`` would silently still shuffle.  This
    converter parses the usual textual spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if value.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % value)

def parse_args():
    """Parse command-line arguments for list creation / record packing.

    Returns the argparse namespace with ``prefix`` and ``root`` normalized
    to absolute paths.  All defaults match the documented behavior; boolean
    options accept true/false spellings and ``--exts`` takes one or more
    extensions (``type=list`` previously split a single argument into
    individual characters).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Create an image list or \
        make a record database by reading from an image list')
    parser.add_argument('prefix', help='prefix of input/output lst and rec files.')
    parser.add_argument('root', help='path to folder containing images.')
    cgroup = parser.add_argument_group('Options for creating image lists')
    cgroup.add_argument('--list', type=_str2bool, default=False,
                        help='If this is set im2rec will create image list(s) by traversing root folder\
        and output to <prefix>.lst.\
        Otherwise im2rec will read <prefix>.lst and create a database at <prefix>.rec')
    cgroup.add_argument('--exts', nargs='+', default=['.jpeg', '.jpg'],
                        help='list of acceptable image extensions.')
    cgroup.add_argument('--chunks', type=int, default=1, help='number of chunks.')
    cgroup.add_argument('--train-ratio', type=float, default=1.0,
                        help='Ratio of images to use for training.')
    cgroup.add_argument('--test-ratio', type=float, default=0,
                        help='Ratio of images to use for testing.')
    cgroup.add_argument('--recursive', type=_str2bool, default=False,
                        help='If true recursively walk through subdirs and assign an unique label\
        to images in each folder. Otherwise only include images in the root folder\
        and give them label 0.')
    cgroup.add_argument('--shuffle', type=_str2bool, default=True, help='If this is set as True, \
        im2rec will randomize the image order in <prefix>.lst')
    rgroup = parser.add_argument_group('Options for creating database')
    rgroup.add_argument('--pass-through', type=_str2bool, default=False,
                        help='whether to skip transformation and save image as is')
    rgroup.add_argument('--resize', type=int, default=0,
                        help='resize the shorter edge of image to the newsize, original images will\
        be packed by default.')
    rgroup.add_argument('--center-crop', type=_str2bool, default=False,
                        help='specify whether to crop the center image to make it rectangular.')
    rgroup.add_argument('--quality', type=int, default=95,
                        help='JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9')
    rgroup.add_argument('--num-thread', type=int, default=1,
                        help='number of thread to use for encoding. order of images will be different\
        from the input list if >1. the input list will be modified to match the\
        resulting order.')
    rgroup.add_argument('--color', type=int, default=1, choices=[-1, 0, 1],
                        help='specify the color mode of the loaded image.\
        1: Loads a color image. Any transparency of image will be neglected. It is the default flag.\
        0: Loads image in grayscale mode.\
        -1:Loads image as such including alpha channel.')
    rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'],
                        help='specify the encoding of the images.')
    rgroup.add_argument('--pack-label', type=_str2bool, default=False,
                        help='Whether to also pack multi dimensional label in the record file')
    args = parser.parse_args()
    # normalize both paths so later prefix/startswith comparisons are stable
    args.prefix = os.path.abspath(args.prefix)
    args.root = os.path.abspath(args.root)
    return args
if __name__ == '__main__':
    args = parse_args()
    if args.list:
        make_list(args)
    else:
        # working_dir holds the .lst inputs and receives the .rec/.idx outputs
        if os.path.isdir(args.prefix):
            working_dir = args.prefix
        else:
            working_dir = os.path.dirname(args.prefix)
        files = [os.path.join(working_dir, fname) for fname in os.listdir(working_dir)
                 if os.path.isfile(os.path.join(working_dir, fname))]
        count = 0
        for fname in files:
            if fname.startswith(args.prefix) and fname.endswith('.lst'):
                print('Creating .rec file from', fname, 'in', working_dir)
                count += 1
                image_list = read_list(fname)
                # -- write_record -- #
                if args.num_thread > 1 and multiprocessing is not None:
                    # fan decoding out to worker processes; a dedicated writer
                    # process restores list order before writing
                    q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]
                    q_out = multiprocessing.Queue(1024)
                    read_process = [multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out)) \
                        for i in range(args.num_thread)]
                    for p in read_process:
                        p.start()
                    write_process = multiprocessing.Process(target=write_worker, args=(q_out, fname, working_dir))
                    write_process.start()
                    for i, item in enumerate(image_list):
                        q_in[i % len(q_in)].put((i, item))
                    for q in q_in:
                        q.put(None)  # one sentinel per reader process
                    for p in read_process:
                        p.join()
                    q_out.put(None)  # tell the writer no more results are coming
                    write_process.join()
                else:
                    print('multiprocessing not available, fall back to single threaded encoding')
                    try:
                        import Queue as queue  # Python 2
                    except ImportError:
                        import queue           # Python 3
                    q_out = queue.Queue()
                    fname = os.path.basename(fname)
                    fname_rec = os.path.splitext(fname)[0] + '.rec'
                    fname_idx = os.path.splitext(fname)[0] + '.idx'
                    record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
                                                           os.path.join(working_dir, fname_rec), 'w')
                    cnt = 0
                    pre_time = time.time()
                    for i, item in enumerate(image_list):
                        image_encode(args, i, item, q_out)
                        if q_out.empty():
                            continue
                        _, s, _ = q_out.get()
                        # image_encode enqueues a None payload on failure;
                        # writing None would crash, so skip that record
                        if s is not None:
                            record.write_idx(item[0], s)
                        if cnt % 1000 == 0:
                            cur_time = time.time()
                            print('time:', cur_time - pre_time, ' count:', cnt)
                            pre_time = cur_time
                        cnt += 1
        if not count:
            print('Did not find any list file with prefix %s' % args.prefix)