blob: a0f16f67408ada433cc3a5e5662e27b51937b14e [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
import logging
import os
import time
import mxnet as mx
from common import modelzoo
from mxnet import nd
from mxnet.contrib.quantization import *
from mxnet.base import _LIB
def download_dataset(dataset_url, dataset_dir, logger=None):
    """Download the validation dataset from *dataset_url* into *dataset_dir*.

    Thin wrapper around ``mx.test_utils.download`` that optionally logs the
    transfer before starting it.
    """
    if logger is not None:
        msg = 'Downloading dataset for inference from %s to %s' % (dataset_url, dataset_dir)
        logger.info(msg)
    mx.test_utils.download(dataset_url, dataset_dir)
def download_model(model_name, logger=None):
    """Download *model_name* from the MXNet model zoo into ``./model``.

    Parameters
    ----------
    model_name : str
        Model-zoo identifier (e.g. 'imagenet1k-resnet-152').
    logger : logging.Logger, optional
        If given, an info message is emitted before the download starts.

    Returns
    -------
    The (prefix, epoch) pair returned by ``modelzoo.download_model``,
    suitable for ``mx.model.load_checkpoint``.
    """
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_path = os.path.join(dir_path, 'model')
    if logger is not None:
        logger.info('Downloading model %s... into path %s' % (model_name, model_path))
    # Fix: use the model_name parameter (the original referenced the
    # module-global `args.model`, which only exists when run as a script)
    # and reuse the already-computed model_path.
    return modelzoo.download_model(model_name, model_path)
def advance_data_iter(data_iter, n):
    """Skip the first *n* batches of *data_iter*.

    Returns the same iterator, now positioned *n* batches in, or ``None``
    (implicitly) if the iterator is exhausted before *n* batches are consumed.
    """
    assert n >= 0
    if n == 0:
        return data_iter
    remaining = n
    try:
        while remaining > 0:
            data_iter.next()
            remaining -= 1
    except StopIteration:
        # Iterator ran dry before we skipped enough batches; fall through
        # and return None, matching the original behaviour.
        return None
    return data_iter
def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples, logger=None):
    """Run float32 inference over *data* and log accuracy and throughput.

    Parameters
    ----------
    sym : Symbol
        Network symbol to evaluate.
    arg_params, aux_params : dict
        Parameter arrays loaded from a checkpoint.
    data : DataIter
        Iterator yielding batches with both data and labels.
    devs : list of Context
        Device(s) to bind the module on.
    label_name : str
        Name of the label input of the symbol.
    max_num_examples : int or None
        Stop after roughly this many images; None means score everything.
    logger : logging.Logger, optional
        Destination for the accuracy / images-per-second summary.
    """
    metrics = [mx.metric.create('acc'),
               mx.metric.create('top_k_accuracy', top_k=5)]
    mod = mx.mod.Module(symbol=sym, context=devs, label_names=[label_name, ])
    mod.bind(for_training=False,
             data_shapes=data.provide_data,
             label_shapes=data.provide_label)
    mod.set_params(arg_params, aux_params)
    tic = time.time()
    num = 0
    for batch in data:
        mod.forward(batch, is_train=False)
        for m in metrics:
            mod.update_metric(m, batch.label)
        # Fix: derive the batch size from the batch itself; the original
        # read a `batch_size` global that only exists in the __main__ block.
        num += batch.data[0].shape[0]
        if max_num_examples is not None and num >= max_num_examples:
            break
    speed = num / (time.time() - tic)
    if logger is not None:
        logger.info('Finished inference with %d images', num)
        logger.info('Finished with %f images per second', speed)
        for m in metrics:
            logger.info(m.get())
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Score a model on a dataset')
    parser.add_argument('--model', type=str, required=True,
                        choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
                        help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--label-name', type=str, default='softmax_label')
    parser.add_argument('--dataset', type=str, required=True, help='dataset path')
    parser.add_argument('--rgb-mean', type=str, default='0,0,0')
    parser.add_argument('--image-shape', type=str, default='3,224,224')
    parser.add_argument('--data-nthreads', type=int, default=60, help='number of threads for data decoding')
    parser.add_argument('--num-skipped-batches', type=int, default=0, help='skip the number of batches for inference')
    parser.add_argument('--num-inference-batches', type=int, required=True, help='number of images used for inference')
    # NOTE(review): action='store_true' together with default=True makes this
    # flag a no-op on the command line; kept as-is for CLI compatibility.
    parser.add_argument('--shuffle-dataset', action='store_true', default=True,
                        help='shuffle the calibration dataset')
    parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304,
                        help='shuffling chunk seed, see'
                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
                             ' for more details')
    parser.add_argument('--shuffle-seed', type=int, default=48564309,
                        help='shuffling seed, see'
                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
                             ' for more details')
    parser.add_argument('--subgraph-backend', type=str, default='default', help='subgraph backend name.')
    parser.add_argument('--ctx', type=str, default='cpu')
    args = parser.parse_args()

    logging.basicConfig()
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)

    data_nthreads = args.data_nthreads
    batch_size = args.batch_size
    logger.info('batch size = %d for inference' % batch_size)

    rgb_mean = args.rgb_mean
    logger.info('rgb_mean = %s' % rgb_mean)
    rgb_mean = [float(i) for i in rgb_mean.split(',')]
    mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}

    label_name = args.label_name
    logger.info('label_name = %s' % label_name)

    image_shape = args.image_shape
    data_shape = tuple([int(i) for i in image_shape.split(',')])
    logger.info('Input data shape = %s' % str(data_shape))

    dataset = args.dataset
    download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
    logger.info('Dataset for inference: %s' % dataset)

    # subgraph_backend is never None here (the CLI default is 'default');
    # the `is not None` guards below are kept for safety if the default changes.
    subgraph_backend = args.subgraph_backend

    if args.ctx == 'cpu':
        ctx = mx.cpu()
    elif args.ctx == 'gpu':
        ctx = mx.gpu(0)
    else:
        raise ValueError('unknown ctx option, only cpu and gpu are supported')

    # creating data iterator
    # Fix: honor the --shuffle-dataset / --shuffle-chunk-seed / --shuffle-seed
    # arguments; the original hard-coded shuffle=True and the literal seed
    # values, silently ignoring whatever the user passed.
    data = mx.io.ImageRecordIter(path_imgrec=dataset,
                                 label_width=1,
                                 preprocess_threads=data_nthreads,
                                 batch_size=batch_size,
                                 data_shape=data_shape,
                                 label_name=label_name,
                                 rand_crop=False,
                                 rand_mirror=False,
                                 shuffle=args.shuffle_dataset,
                                 shuffle_chunk_seed=args.shuffle_chunk_seed,
                                 seed=args.shuffle_seed,
                                 **mean_args)

    # download model and load the float32 checkpoint
    prefix, epoch = download_model(model_name=args.model, logger=logger)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

    # operators handed to the default subgraph property for fusion
    op_names = ['BatchNorm', 'Convolution', 'Pooling', 'Activation']
    if subgraph_backend is not None:
        os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
        if subgraph_backend == 'default':
            check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                         c_str_array(op_names)))

    # make sure that fp32 inference works on the same images as calibrated quantized model
    logger.info('Skipping the first %d batches' % args.num_skipped_batches)
    data = advance_data_iter(data, args.num_skipped_batches)

    num_inference_images = args.num_inference_batches * batch_size
    logger.info('Running model %s for inference' % args.model)
    score(sym, arg_params, aux_params, data, [ctx], label_name,
          max_num_examples=num_inference_images, logger=logger)

    # undo the subgraph-backend environment changes made above
    if subgraph_backend is not None:
        del os.environ['MXNET_SUBGRAPH_BACKEND']
        if subgraph_backend == 'default':
            check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))