| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """Helper utility for downloading""" |
| import os |
| import sys |
| import time |
| import uuid |
| import shutil |
| |
| |
def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=3):
    """Download a file from a URL to a local path.

    If ``path`` already exists and ``overwrite`` is False the download is
    skipped. When ``size_compare`` is also set, the local file size is first
    compared with the remote ``Content-Length`` and the file is downloaded
    again if the two differ.

    The data is first written to a uniquely named temporary file next to
    ``path`` and only moved into place once the transfer has finished.

    Parameters
    ----------
    url : str
        Download url.

    path : str
        Local file path to save downloaded file

    overwrite : bool, optional
        Whether to overwrite existing file

    size_compare : bool, optional
        Whether to do size compare to check downloaded file.

    verbose: int, optional
        Verbose level

    retries: int, optional
        Number of download attempts before giving up, default at 3.
        At least one attempt is always made, so values below 1 behave like 1.

    Raises
    ------
    Exception
        The error of the last failed attempt, once all attempts are used up.
    """
    # pylint: disable=import-outside-toplevel
    import urllib.request as urllib2

    if os.path.isfile(path) and not overwrite:
        if size_compare:
            import requests

            file_size = os.path.getsize(path)
            res_head = requests.head(url)
            has_length = "Content-Length" in res_head.headers
            res_head.close()
            # The size is read from a GET response; when the HEAD response
            # carries no Content-Length, fall back to urllib for the GET.
            if has_length:
                res_get = requests.get(url, stream=True)
            else:
                res_get = urllib2.urlopen(url)
            try:
                url_file_size = int(res_get.headers["Content-Length"])
            finally:
                # Only the headers are needed; release the connection.
                res_get.close()
            if url_file_size != file_size:
                print("exist file got corrupted, downloading %s file freshly..." % path)
                download(
                    url,
                    path,
                    overwrite=True,
                    size_compare=False,
                    verbose=verbose,
                    retries=retries,
                )
                return
        print("File {} exists, skip.".format(path))
        return

    if verbose >= 1:
        print("Downloading from url {} to {}".format(url, path))

    # Stateful start time
    start_time = time.time()
    dirpath = os.path.dirname(path)
    if dirpath and not os.path.isdir(dirpath):
        os.makedirs(dirpath)
    # Download into a unique sibling file so a partial transfer never
    # appears under the final name.
    temp_path = os.path.join(dirpath, str(uuid.uuid4()))

    def _download_progress(count, block_size, total_size):
        """Show the download progress."""
        if count == 0:
            return
        duration = time.time() - start_time
        progress_size = int(count * block_size)
        # The clock may not have advanced yet on a very fast transfer.
        speed = int(progress_size / (1024 * duration)) if duration > 0 else 0
        # urlretrieve reports a non-positive total when the size is unknown.
        if total_size > 0:
            percent = min(int(progress_size * 100 / total_size), 100)
        else:
            percent = 0
        sys.stdout.write(
            "\r...%d%%, %.2f MB, %d KB/s, %d seconds passed"
            % (percent, progress_size / (1024.0 * 1024), speed, duration)
        )
        sys.stdout.flush()

    attempts_left = retries
    while True:
        # Disable pylint too broad Exception
        # pylint: disable=W0703
        try:
            urllib2.urlretrieve(url, temp_path, reporthook=_download_progress)
            print("")
            shutil.move(temp_path, path)
            break
        except Exception as err:
            attempts_left -= 1
            # "<= 0" (not "== 0") so that retries=0 or a negative value still
            # surfaces the failure instead of silently returning.
            if attempts_left <= 0:
                if os.path.exists(temp_path):
                    os.remove(temp_path)
                raise
            print(
                "download failed due to {}, retrying, {} attempt{} left".format(
                    repr(err), attempts_left, "s" if attempts_left > 1 else ""
                )
            )
| |
| |
# Root directory for downloaded test data: taken from the TEST_DATA_ROOT_PATH
# environment variable when it is set, otherwise ~/.tvm_test_data.
TEST_DATA_ROOT_PATH = os.environ.get(
    "TEST_DATA_ROOT_PATH", os.path.join(os.path.expanduser("~"), ".tvm_test_data")
)
# Make sure the directory exists as soon as the module is imported.
os.makedirs(TEST_DATA_ROOT_PATH, exist_ok=True)
| |
| |
def download_testdata(url, relpath, module=None):
    """Download test data into the test data folder and return its path.

    The target is ``TEST_DATA_ROOT_PATH/<module>/<relpath>``. The download is
    delegated to ``download`` with ``overwrite=False`` and
    ``size_compare=False``.

    Parameters
    ----------
    url : str
        Download url.

    relpath : str
        Relative file path.

    module : Union[str, list, tuple], optional
        Subdirectory paths under test data folder. ``None`` or an empty
        sequence places the file directly under the test data folder.

    Returns
    -------
    abspath : str
        Absolute file path of downloaded file

    Raises
    ------
    ValueError
        If ``module`` is not None, a str, a list or a tuple.
    """
    if module is None:
        module_path = ""
    elif isinstance(module, str):
        module_path = module
    elif isinstance(module, (list, tuple)):
        # os.path.join() requires at least one component, so an empty
        # sequence is treated as "no subdirectory".
        module_path = os.path.join(*module) if module else ""
    else:
        # module is not a str here, so it must be formatted rather than
        # concatenated (concatenation would raise TypeError instead).
        raise ValueError("Unsupported module: {!r}".format(module))
    abspath = os.path.join(TEST_DATA_ROOT_PATH, module_path, relpath)
    download(url, abspath, overwrite=False, size_compare=False)
    return abspath