| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import subprocess |
| import os |
| import errno |
| |
| import mxnet as mx |
| |
def download_file(url, local_fname=None, force_write=False):
    """Download ``url`` to ``local_fname`` and return the local path.

    Parameters
    ----------
    url : str
        Address fetched with an HTTP GET (streamed in 1 KiB chunks).
    local_fname : str, optional
        Destination path. Defaults to the last ``/``-separated component of
        ``url``, resolved relative to the current working directory.
    force_write : bool, optional
        When False (the default) and ``local_fname`` already exists, nothing
        is downloaded and the existing path is returned as-is.

    Returns
    -------
    str
        The local file name that was written (or already existed).

    Raises
    ------
    AssertionError
        If the server does not answer with HTTP status 200.
    OSError
        If the destination directory cannot be created or the file cannot
        be written.
    """
    if local_fname is None:
        local_fname = url.split('/')[-1]
    if not force_write and os.path.exists(local_fname):
        return local_fname

    # requests is not installed by default; import it only when a download is
    # actually needed so the cached path above works without it.
    import requests

    dir_name = os.path.dirname(local_fname)
    if dir_name != "" and not os.path.exists(dir_name):
        try:  # the directory may be created concurrently by someone else
            os.makedirs(dir_name)
        except OSError as exc:
            if exc.errno != errno.EEXIST:
                raise

    r = requests.get(url, stream=True)
    try:
        # Explicit check rather than ``assert`` so it still runs under
        # ``python -O``; AssertionError is kept for existing callers.
        if r.status_code != 200:
            raise AssertionError("failed to open %s" % url)
        try:
            with open(local_fname, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
        except BaseException:
            # Don't leave a truncated file behind: with force_write=False a
            # later call would otherwise return it as if it were complete.
            try:
                os.remove(local_fname)
            except OSError:
                pass
            raise
    finally:
        # A streamed response holds its connection until it is closed.
        r.close()
    return local_fname
| |
def get_gpus():
    """Return a list of GPU device indices.

    Returns
    -------
    list of int
        ``[0, 1, ..., n - 1]`` where ``n`` is the count reported by
        ``mx.util.get_gpu_count()``; an empty list when it reports 0.
    """
    # Materialize the range: on Python 3 ``range`` is a lazy sequence object,
    # not the list this function promises.
    return list(range(mx.util.get_gpu_count()))