# blob: 2989bc02a4f7659e46287d7f1b35ad468eb53541
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
import mxnet as mx
from rcnn.logger import logger
from rcnn.config import config, default, generate_config
from rcnn.tools.test_rcnn import test_rcnn
def parse_args():
    """Parse the command-line options for testing a Faster R-CNN network.

    Parsing is done in two passes. The first pass reads only ``--network``
    and ``--dataset`` and hands them to ``generate_config``; the remaining
    options are registered afterwards, so their defaults are read from
    ``default`` only after that call (presumably ``generate_config`` adjusts
    ``default`` for the chosen network/dataset -- confirm in rcnn.config).

    Returns:
        argparse.Namespace with attributes ``network``, ``dataset``,
        ``image_set``, ``root_path``, ``dataset_path``, ``prefix``, ``epoch``,
        ``gpu``, ``vis``, ``thresh``, ``shuffle``, ``has_rpn`` and
        ``proposal``.
    """
    # add_help=False: otherwise the first parse_known_args() pass would consume
    # -h/--help itself and print a help text that lists only --network and
    # --dataset. The help action is registered manually once every option exists.
    parser = argparse.ArgumentParser(description='Test a Faster R-CNN network', add_help=False)
    # general
    parser.add_argument('--network', help='network name', default=default.network, type=str)
    parser.add_argument('--dataset', help='dataset name', default=default.dataset, type=str)
    # First pass: unknown options (everything registered below) are ignored here.
    args, _ = parser.parse_known_args()
    generate_config(args.network, args.dataset)
    parser.add_argument('--image_set', help='image_set name', default=default.test_image_set, type=str)
    parser.add_argument('--root_path', help='output data folder', default=default.root_path, type=str)
    parser.add_argument('--dataset_path', help='dataset path', default=default.dataset_path, type=str)
    # testing
    parser.add_argument('--prefix', help='model to test with', default=default.e2e_prefix, type=str)
    parser.add_argument('--epoch', help='epoch of the model to test with', default=default.e2e_epoch, type=int)
    parser.add_argument('--gpu', help='GPU device to test with', default=0, type=int)
    # rcnn
    parser.add_argument('--vis', help='turn on visualization', action='store_true')
    parser.add_argument('--thresh', help='valid detection threshold', default=1e-3, type=float)
    parser.add_argument('--shuffle', help='shuffle data on visualization', action='store_true')
    # --has_rpn is kept for backward compatibility, but store_true with
    # default=True means it can never yield False; --no_rpn is the switch that
    # actually turns on-the-fly proposal generation off.
    parser.add_argument('--has_rpn', help='generate proposals on the fly', action='store_true', default=True)
    parser.add_argument('--no_rpn', dest='has_rpn', help='do not generate proposals on the fly',
                        action='store_false')
    parser.add_argument('--proposal', help='can be ss for selective search or rpn', default='rpn', type=str)
    # Registered last so that --help describes the complete option set.
    parser.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS,
                        help='show this help message and exit')
    args = parser.parse_args()
    return args
def main():
    """Script entry point.

    Parses the command-line options, logs them, creates an MXNet GPU context
    for the device id given by ``--gpu`` and passes everything on to
    ``test_rcnn``.
    """
    opts = parse_args()
    logger.info('Called with argument: %s' % opts)
    device = mx.gpu(opts.gpu)

    # test_rcnn is called positionally, so the grouping below only exists for
    # readability; the concatenated order is the call's argument order.
    data_opts = (opts.network, opts.dataset, opts.image_set, opts.root_path, opts.dataset_path)
    model_opts = (device, opts.prefix, opts.epoch)
    rcnn_opts = (opts.vis, opts.shuffle, opts.has_rpn, opts.proposal, opts.thresh)
    test_rcnn(*(data_opts + model_opts + rcnn_opts))
# Run the test only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()