blob: c031f04d4fe78892e5f2647468c634eddcc6d74d [file] [log] [blame]
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import print_function
import sys, os
import argparse
import subprocess
curr_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(curr_path, '..'))
from dataset.pascal_voc import PascalVoc
from dataset.mscoco import Coco
from dataset.concat_db import ConcatDB
def load_pascal(image_set, year, devkit_path, shuffle=False):
    """
    wrapper function for loading pascal voc dataset

    Parameters:
    ----------
    image_set : str
        train, trainval...
    year : str
        2007, 2012 or combinations splitted by comma
    devkit_path : str
        root directory of dataset
    shuffle : bool
        whether to shuffle initial list

    Returns:
    ----------
    Imdb
    """
    # filter out empty entries: str.split(',') always yields at least [''],
    # so without the filter the asserts below could never fire
    image_set = [y.strip() for y in image_set.split(',') if y.strip()]
    assert image_set, "No image_set specified"
    year = [y.strip() for y in year.split(',') if y.strip()]
    assert year, "No year specified"
    # make sure (# sets == # years): broadcast the single-valued side
    if len(image_set) > 1 and len(year) == 1:
        year = year * len(image_set)
    if len(image_set) == 1 and len(year) > 1:
        image_set = image_set * len(year)
    assert len(image_set) == len(year), "Number of sets and year mismatch"

    imdbs = []
    for s, y in zip(image_set, year):
        imdbs.append(PascalVoc(s, y, devkit_path, shuffle, is_train=True))
    if len(imdbs) > 1:
        # combine multiple set/year databases into one
        return ConcatDB(imdbs, shuffle)
    else:
        return imdbs[0]
def load_coco(image_set, dirname, shuffle=False):
    """
    wrapper function for loading ms coco dataset

    Parameters:
    ----------
    image_set : str
        train2014, val2014, valminusminival2014, minival2014
    dirname: str
        root dir for coco
    shuffle: boolean
        initial shuffle

    Returns:
    ----------
    Imdb
    """
    # skip empty entries: str.split(',') always yields at least [''], so
    # without the filter the assert below could never fire and an empty
    # input would silently map to a bogus 'instances_.json'
    anno_files = ['instances_' + y.strip() + '.json'
                  for y in image_set.split(',') if y.strip()]
    assert anno_files, "No image set specified"
    imdbs = []
    for af in anno_files:
        af_path = os.path.join(dirname, 'annotations', af)
        imdbs.append(Coco(af_path, dirname, shuffle=shuffle))
    if len(imdbs) > 1:
        # combine multiple annotation sets into one database
        return ConcatDB(imdbs, shuffle)
    else:
        return imdbs[0]
def parse_args():
    """Parse command-line options for preparing dataset list files."""
    arg_parser = argparse.ArgumentParser(description='Prepare lists for dataset')
    arg_parser.add_argument('--dataset', dest='dataset', type=str,
                            default='pascal', help='dataset to use')
    arg_parser.add_argument('--year', dest='year', type=str,
                            default='2007,2012', help='which year to use')
    arg_parser.add_argument('--set', dest='set', type=str,
                            default='trainval', help='train, val, trainval, test')
    arg_parser.add_argument('--target', dest='target', type=str,
                            default=os.path.join(curr_path, '..', 'train.lst'),
                            help='output list file')
    arg_parser.add_argument('--root', dest='root_path', type=str,
                            default=os.path.join(curr_path, '..', 'data', 'VOCdevkit'),
                            help='dataset root path')
    # note: flag is --no-shuffle, so args.shuffle defaults to True
    arg_parser.add_argument('--no-shuffle', dest='shuffle', action='store_false',
                            help='shuffle list')
    arg_parser.add_argument('--num-thread', dest='num_thread', type=int, default=1,
                            help='number of thread to use while runing im2rec.py')
    return arg_parser.parse_args()
if __name__ == '__main__':
    args = parse_args()

    # build the image database for the requested dataset
    if args.dataset == 'pascal':
        db = load_pascal(args.set, args.year, args.root_path, args.shuffle)
    elif args.dataset == 'coco':
        db = load_coco(args.set, args.root_path, args.shuffle)
    else:
        raise NotImplementedError("No implementation for dataset: " + args.dataset)

    print("saving list to disk...")
    db.save_imglist(args.target, root=args.root_path)
    print("List file {} generated...".format(args.target))

    # delegate .rec packing to the shared im2rec tool
    im2rec_cmd = ["python",
                  os.path.join(curr_path, "../../../tools/im2rec.py"),
                  os.path.abspath(args.target), os.path.abspath(args.root_path),
                  "--pack-label", "--num-thread", str(args.num_thread)]
    if not args.shuffle:
        im2rec_cmd.append("--no-shuffle")
    subprocess.check_call(im2rec_cmd)
    print("Record file {} generated...".format(args.target.split('.')[0] + '.rec'))