blob: e6c2949cbd5205187c58dc37a1dc97caa8dc2136 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import argparse
import requests
import errno
import os
# Model names accepted by the --model command-line option (see main()).
models = [
    "imagenet1k-inception-bn",
    "imagenet1k-resnet-50",
    "imagenet1k-resnet-152",
    "imagenet1k-resnet-18",
]
def download(url, fname=None, dirname=None, overwrite=False, retries=5):
    """Download a given URL to a local file.

    The payload is streamed into a temporary ``<fname>.part`` file next to the
    destination and only moved into place once the transfer has completed, so
    a failed or interrupted download never leaves a truncated file behind that
    a later call (with ``overwrite=False``) would mistake for a finished one.

    Parameters
    ----------
    url : str
        URL to download.
    fname : str, optional
        Filename of the downloaded file. If None, the last ``/``-separated
        component of ``url`` is used.
    dirname : str, optional
        Output directory name. If None, it is taken from ``fname`` (an empty
        result means the current directory). If given, ``fname`` is joined
        onto it. A missing directory is created.
    overwrite : bool, optional
        Default is False, which skips the download if the local file already
        exists. If True, the URL is downloaded and replaces the local file.
    retries : int, default 5
        Number of attempts made in case of failure (network errors, non-200
        status codes, I/O errors). Values 0 and 1 both mean a single attempt.
        When the attempts are exhausted the last error is re-raised.

    Returns
    -------
    str
        The path of the downloaded file.

    Raises
    ------
    OSError
        If the output directory cannot be created.
    RuntimeError
        If the server answers with a non-200 status code on the last attempt.
    """
    assert retries >= 0, "Number of retries should be at least 0"
    if fname is None:
        fname = url.split('/')[-1]
    if dirname is None:
        dirname = os.path.dirname(fname)
    else:
        fname = os.path.join(dirname, fname)
    if dirname != "" and not os.path.exists(dirname):
        try:
            logging.info('create directory %s', dirname)
            os.makedirs(dirname)
        except OSError as exc:
            # EEXIST means the directory appeared in the meantime, which is fine.
            if exc.errno != errno.EEXIST:
                raise OSError('failed to create ' + dirname) from exc
    if not overwrite and os.path.exists(fname):
        logging.info("%s exists, skipping download", fname)
        return fname

    # Stream into a sibling temp file and rename on success (same directory,
    # so os.replace is an in-place rename) to never expose a partial download.
    tmp_fname = fname + '.part'
    while True:
        # Any failure is retried, hence the deliberately broad except below.
        # Disable pylint too broad Exception
        # pylint: disable=W0703
        try:
            # Using the response as a context manager releases the connection
            # even when we bail out on a bad status code.
            with requests.get(url, stream=True) as r:
                if r.status_code != 200:
                    # Explicit raise rather than assert: asserts are stripped
                    # under -O, which would save an error page as the file.
                    raise RuntimeError(
                        f"failed to open {url} (HTTP status {r.status_code})")
                with open(tmp_fname, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=1024):
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
            os.replace(tmp_fname, fname)
            break
        except Exception as e:
            try:
                os.remove(tmp_fname)
            except OSError:
                pass  # best effort: nothing was written, or it is already gone
            retries -= 1
            if retries <= 0:
                raise
            logging.warning("download of %s failed (%r), retrying, %d attempt%s left",
                            url, e, retries, 's' if retries > 1 else '')
    logging.info("downloaded %s into %s successfully", url, fname)
    return fname
def download_model(model_name, dst_dir='./', meta_info=None):
    """Download a pretrained symbolic model plus the ImageNet synset file.

    Parameters
    ----------
    model_name : str
        Model name to download; must be a key of ``meta_info``.
    dst_dir : str
        Destination directory for the symbol, params and synset files. It is
        created (including missing parent directories) if it does not exist.
    meta_info : dict of dict, optional
        Mapping from model_name to dict of the following structure:
        ``{'symbol': url, 'params': url}``. Defaults to a built-in table of
        models hosted under http://data.mxnet.io/models/.

    Returns
    -------
    tuple
        Two element tuple ``(prefix, epoch)``: ``prefix`` is ``dst_dir``
        joined with ``model_name`` (files are ``prefix-symbol.json`` and
        ``prefix-0000.params``) and ``epoch`` is always 0. Returns
        ``(None, 0)`` without downloading anything if ``model_name`` is not
        in ``meta_info``.
    """
    _base_model_url = 'http://data.mxnet.io/models/'
    _default_model_info = {
        'imagenet1k-inception-bn': {'symbol':_base_model_url+'imagenet/inception-bn/Inception-BN-symbol.json',
                                    'params':_base_model_url+'imagenet/inception-bn/Inception-BN-0126.params'},
        'imagenet1k-resnet-18': {'symbol':_base_model_url+'imagenet/resnet/18-layers/resnet-18-symbol.json',
                                 'params':_base_model_url+'imagenet/resnet/18-layers/resnet-18-0000.params'},
        'imagenet1k-resnet-34': {'symbol':_base_model_url+'imagenet/resnet/34-layers/resnet-34-symbol.json',
                                 'params':_base_model_url+'imagenet/resnet/34-layers/resnet-34-0000.params'},
        'imagenet1k-resnet-50': {'symbol':_base_model_url+'imagenet/resnet/50-layers/resnet-50-symbol.json',
                                 'params':_base_model_url+'imagenet/resnet/50-layers/resnet-50-0000.params'},
        'imagenet1k-resnet-101': {'symbol':_base_model_url+'imagenet/resnet/101-layers/resnet-101-symbol.json',
                                  'params':_base_model_url+'imagenet/resnet/101-layers/resnet-101-0000.params'},
        'imagenet1k-resnet-152': {'symbol':_base_model_url+'imagenet/resnet/152-layers/resnet-152-symbol.json',
                                  'params':_base_model_url+'imagenet/resnet/152-layers/resnet-152-0000.params'},
        'imagenet1k-resnext-50': {'symbol':_base_model_url+'imagenet/resnext/50-layers/resnext-50-symbol.json',
                                  'params':_base_model_url+'imagenet/resnext/50-layers/resnext-50-0000.params'},
        'imagenet1k-resnext-101': {'symbol':_base_model_url+'imagenet/resnext/101-layers/resnext-101-symbol.json',
                                   'params':_base_model_url+'imagenet/resnext/101-layers/resnext-101-0000.params'},
        'imagenet1k-resnext-101-64x4d':
            {'symbol':_base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-symbol.json',
             'params':_base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-0000.params'},
        'imagenet11k-resnet-152':
            {'symbol':_base_model_url+'imagenet-11k/resnet-152/resnet-152-symbol.json',
             'params':_base_model_url+'imagenet-11k/resnet-152/resnet-152-0000.params'},
        'imagenet11k-place365ch-resnet-152':
            {'symbol':_base_model_url+'imagenet-11k-place365-ch/resnet-152-symbol.json',
             'params':_base_model_url+'imagenet-11k-place365-ch/resnet-152-0000.params'},
        'imagenet11k-place365ch-resnet-50':
            {'symbol':_base_model_url+'imagenet-11k-place365-ch/resnet-50-symbol.json',
             'params':_base_model_url+'imagenet-11k-place365-ch/resnet-50-0000.params'},
    }
    if meta_info is None:
        meta_info = _default_model_info
    meta_info = dict(meta_info)
    if model_name not in meta_info:
        return (None, 0)
    meta = dict(meta_info[model_name])
    # Validate both URLs before touching the disk or the network, so a bad
    # entry does not leave a half-downloaded model behind.
    assert 'symbol' in meta, "missing symbol url"
    assert 'params' in meta, "missing parameter file url"
    # makedirs also creates missing parents; exist_ok avoids a check/create race.
    os.makedirs(dst_dir, exist_ok=True)
    model_name = os.path.join(dst_dir, model_name)
    download(meta['symbol'], model_name+'-symbol.json')
    download(meta['params'], model_name+'-0000.params')
    # Keep the synset next to the model files. For the default dst_dir './'
    # this is the same file the current directory would have received.
    download(_base_model_url + 'imagenet/synset.txt', dirname=dst_dir)
    return (model_name, 0)
def main(argv=None):
    """Command-line entry point: download the model selected with ``--model``.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse instead of ``sys.argv[1:]`` (handy for testing).
    """
    # basicConfig() alone leaves the root logger at WARNING, which silently
    # dropped the INFO progress messages that download() logs on the root logger.
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(
        description='Download a pretrained symbolic model for multithreaded inference')
    parser.add_argument("--model", type=str, choices=models, required=True)
    args = parser.parse_args(argv)
    download_model(args.model)


if __name__ == "__main__":
    main()