blob: 7bd696840766fe7f392d0dc8cbd87930181bd21e [file] [log] [blame]
from __future__ import print_function
import sys, os
import argparse
import subprocess
curr_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(curr_path, '..'))
from dataset.pascal_voc import PascalVoc
from dataset.concat_db import ConcatDB
def load_pascal(image_set, year, devkit_path, shuffle=False):
"""
wrapper function for loading pascal voc dataset
Parameters:
----------
image_set : str
train, trainval...
year : str
2007, 2012 or combinations splitted by comma
devkit_path : str
root directory of dataset
shuffle : bool
whether to shuffle initial list
Returns:
----------
Imdb
"""
image_set = [y.strip() for y in image_set.split(',')]
assert image_set, "No image_set specified"
year = [y.strip() for y in year.split(',')]
assert year, "No year specified"
# make sure (# sets == # years)
if len(image_set) > 1 and len(year) == 1:
year = year * len(image_set)
if len(image_set) == 1 and len(year) > 1:
image_set = image_set * len(year)
assert len(image_set) == len(year), "Number of sets and year mismatch"
imdbs = []
for s, y in zip(image_set, year):
imdbs.append(PascalVoc(s, y, devkit_path, shuffle, is_train=True))
if len(imdbs) > 1:
return ConcatDB(imdbs, shuffle)
else:
return imdbs[0]
def parse_args():
parser = argparse.ArgumentParser(description='Prepare lists for dataset')
parser.add_argument('--dataset', dest='dataset', help='dataset to use',
default='pascal', type=str)
parser.add_argument('--year', dest='year', help='which year to use',
default='2007,2012', type=str)
parser.add_argument('--set', dest='set', help='train, val, trainval, test',
default='trainval', type=str)
parser.add_argument('--target', dest='target', help='output list file',
default=os.path.join(curr_path, '..', 'train.lst'),
type=str)
parser.add_argument('--root', dest='root_path', help='dataset root path',
default=os.path.join(curr_path, '..', 'data', 'VOCdevkit'),
type=str)
parser.add_argument('--shuffle', dest='shuffle', help='shuffle list',
type=bool, default=True)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
if args.dataset == 'pascal':
db = load_pascal(args.set, args.year, args.root_path, args.shuffle)
db.save_imglist(args.target, root=args.root_path)
else:
raise NotImplementedError("No implementation for dataset: " + args.dataset)
print("List file {} generated...".format(args.target))
subprocess.check_call(["python",
os.path.join(curr_path, "../../../tools/im2rec.py"),
os.path.abspath(args.target), os.path.abspath(args.root_path),
"--shuffle", str(int(args.shuffle)), "--pack-label", "1"])
print("Record file {} generated...".format(args.target.split('.')[0] + '.rec'))