# blob: 96030245b9e5fb0d92c698ef36c2904d30fbd35b
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Helper utility for downloading"""
import os
import sys
import time
import uuid
import shutil
def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=3):
    """Download the file at ``url`` and store it at ``path``.

    The payload is first written to a uniquely named temporary file inside the
    destination directory and only moved to ``path`` once the transfer has
    finished, so a failed download does not leave a partial file at ``path``.

    Parameters
    ----------
    url : str
        Download url.

    path : str
        Local file path to save downloaded file.

    overwrite : bool, optional
        Whether to overwrite an existing file. When False and ``path`` already
        exists, nothing is downloaded (unless ``size_compare`` detects a
        mismatch).

    size_compare : bool, optional
        Only consulted when ``path`` exists and ``overwrite`` is False: compare
        the local file size with the ``Content-Length`` reported for ``url``
        and download the file again when they differ. Requires the
        ``requests`` package.

    verbose : int, optional
        Verbose level. The "Downloading from url ..." message is printed when
        it is at least 1.

    retries : int, optional
        Number of download attempts made before giving up, default at 3.
        Values below 1 are treated as a single attempt.

    Raises
    ------
    Exception
        The error raised by the last download attempt, once all attempts
        have failed.
    """
    # pylint: disable=import-outside-toplevel
    import urllib.request as urllib2

    if os.path.isfile(path) and not overwrite:
        if size_compare:
            import requests

            file_size = os.path.getsize(path)
            res_head = requests.head(url)
            if "Content-Length" in res_head.headers:
                # stream=True defers fetching the body; only headers are read.
                res_get = requests.get(url, stream=True)
            else:
                res_get = urllib2.urlopen(url)
            try:
                url_file_size = int(res_get.headers["Content-Length"])
            finally:
                # Release the connection; the body is never consumed here.
                res_get.close()
            if url_file_size != file_size:
                print("exist file got corrupted, downloading %s file freshly..." % path)
                # Forward the caller's verbosity and attempt budget.
                download(url, path, True, False, verbose, retries)
                return
        print("File {} exists, skip.".format(path))
        return

    if verbose >= 1:
        print("Downloading from url {} to {}".format(url, path))

    # Stateful start time, read by the progress hook below.
    start_time = time.time()
    dirpath = os.path.dirname(path)
    if dirpath:
        # dirpath is empty when path is a bare file name; makedirs("") fails.
        os.makedirs(dirpath, exist_ok=True)
    temp_path = os.path.join(dirpath, str(uuid.uuid4()))

    def _download_progress(count, block_size, total_size):
        """Show the download progress."""
        if count == 0:
            return
        # Avoid a division by zero when no measurable time has elapsed yet.
        duration = max(time.time() - start_time, 1e-6)
        progress_size = int(count * block_size)
        speed = int(progress_size / (1024 * duration))
        if total_size > 0:
            percent = "%d%%, " % min(int(progress_size * 100 / total_size), 100)
        else:
            # urlretrieve passes -1 when the total size is unknown; omit the
            # percentage instead of printing a meaningless value.
            percent = ""
        sys.stdout.write(
            "\r...%s%.2f MB, %d KB/s, %d seconds passed"
            % (percent, progress_size / (1024.0 * 1024), speed, duration)
        )
        sys.stdout.flush()

    # Always try at least once so that a failure can never pass silently.
    attempts = max(retries, 1)
    try:
        for attempts_left in range(attempts - 1, -1, -1):
            # Disable pylint too broad Exception
            # pylint: disable=W0703
            try:
                urllib2.urlretrieve(url, temp_path, reporthook=_download_progress)
                # Terminate the "\r"-rewritten progress line.
                print("")
                shutil.move(temp_path, path)
                break
            except Exception as err:
                if attempts_left == 0:
                    raise
                print(
                    "download failed due to {}, retrying, {} attempt{} left".format(
                        repr(err), attempts_left, "s" if attempts_left > 1 else ""
                    )
                )
    finally:
        # After a successful move the temporary file no longer exists; on any
        # failure (including interruption) make sure it is not left behind.
        if os.path.exists(temp_path):
            os.remove(temp_path)
# Root directory under which downloaded test data is stored: the value of the
# TEST_DATA_ROOT_PATH environment variable when it is set, otherwise the
# ".tvm_test_data" folder inside the user's home directory.
TEST_DATA_ROOT_PATH = os.environ.get(
    "TEST_DATA_ROOT_PATH",
    os.path.join(os.path.expanduser("~"), ".tvm_test_data"),
)
# Make sure the directory exists as soon as the module is imported.
os.makedirs(TEST_DATA_ROOT_PATH, exist_ok=True)
def download_testdata(url, relpath, module=None):
    """Download test data into the test data folder and return its location.

    The target location is ``TEST_DATA_ROOT_PATH/<module>/<relpath>``. The
    download is delegated to ``download`` with ``overwrite=False``, so a file
    that is already present is not fetched again.

    Parameters
    ----------
    url : str
        Download url.

    relpath : str
        Relative file path.

    module : Union[str, list, tuple], optional
        Subdirectory paths under test data folder. A list or tuple is joined
        into a nested path; ``None`` or an empty sequence places the file
        directly under the test data folder.

    Returns
    -------
    abspath : str
        Absolute file path of downloaded file

    Raises
    ------
    ValueError
        If ``module`` is neither ``None``, a string, a list nor a tuple.
    """
    if module is None:
        module_path = ""
    elif isinstance(module, str):
        module_path = module
    elif isinstance(module, (list, tuple)):
        # os.path.join needs at least one component; an empty sequence simply
        # means "no subdirectory".
        module_path = os.path.join(*module) if module else ""
    else:
        # module is not a str on this branch, so it must be formatted rather
        # than concatenated to build the message.
        raise ValueError("Unsupported module: {!r}".format(module))
    abspath = os.path.join(TEST_DATA_ROOT_PATH, module_path, relpath)
    download(url, abspath, overwrite=False, size_compare=False)
    return abspath