blob: 601645e844c395e665df7c02642fd15300f97e5e [file] [log] [blame]
#!/usr/bin/env python
# coding=utf-8
# Copyright [2020] [Apache Software Foundation]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
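"""Data Module.

Helpers to locate the Marvin data directory (``MARVIN_DATA_PATH``), read
files from it and download remote files into it.
"""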
import os
import requests
import progressbar
# NOTE: despite the original intent to use six for Python 2/3 compatible
# code (https://six.readthedocs.io/), six is not imported by this module.
from .utils import check_path
from .exceptions import InvalidConfigException
from .log import get_logger
# Module-level logger for this file, obtained from the project's get_logger
# helper under the name 'common.data'.
logger = get_logger('common.data')
class MarvinData():
    """Access to the Marvin data directory.

    The directory comes from the ``MARVIN_DATA_PATH`` environment variable
    (see ``get_data_path``); the other helpers read files from it or
    download files into it.
    """
    # Name of the environment variable holding the data directory.
    _key = 'MARVIN_DATA_PATH'

    @classmethod
    def get_data_path(cls):
        """
        Read data path from the following sources in order of priority:

        1. Environment variable (``MARVIN_DATA_PATH``)

        If not found raises an exception.

        :return: str - datapath
        :raises InvalidConfigException: if the variable is unset/empty, or if
            ``check_path`` (called with ``create=True``) reports failure
        """
        marvin_path = os.environ.get(cls._key)
        if not marvin_path:
            raise InvalidConfigException('Data path not set!')
        # check_path is a project helper asked to create the directory;
        # a falsy result is treated as "path not available".
        is_path_created = check_path(marvin_path, create=True)
        if not is_path_created:
            raise InvalidConfigException('Data path does not exist!')
        return marvin_path

    @classmethod
    def _convert_path_to_key(cls, path):
        """
        Convert a filesystem path into a "/"-separated key.

        Paths starting with ``os.path.sep`` are first made relative to the
        data path; the remaining path separators are then replaced by "/".

        :param path: str - absolute or data-path-relative path
        :return: str - key using "/" as separator
        """
        if path.startswith(os.path.sep):
            path = os.path.relpath(path, start=cls.get_data_path())
        return '/'.join(path.split(os.path.sep))

    @classmethod
    def load_data(cls, relpath):
        """
        Load data from the following sources in order of priority:

        1. Filesystem

        :param relpath: path relative to "data_path"
        :return: str - data content
        :raises IOError/OSError: if the file cannot be opened
        """
        filepath = os.path.join(cls.get_data_path(), relpath)
        with open(filepath) as fp:
            content = fp.read()
        return content

    @classmethod
    def download_file(cls, url, local_file_name=None, force=False, chunk_size=1024,
                      timeout=None):
        """
        Download file from a given url into the data path.

        :param url: str - URL to fetch (redirects are followed by requests)
        :param local_file_name: str - target file name inside the data path;
            defaults to the last "/"-separated segment of ``url``
        :param force: bool - download even if the target file already exists
        :param chunk_size: int - number of bytes requested per streamed chunk
        :param timeout: forwarded to ``requests.get``; ``None`` (default)
            waits indefinitely, as before
        :return: str - path of the local file
        :raises ValueError: if no file name can be derived from ``url``
        :raises requests.HTTPError: if the server answers with an error status
        """
        local_file_name = local_file_name if local_file_name else url.split('/')[-1]
        if not local_file_name:
            # e.g. a URL ending in "/": joining '' would point at the data
            # directory itself and it would be returned as the "file".
            raise ValueError(
                'Cannot derive a file name from url {!r}; '
                'pass local_file_name'.format(url))
        filepath = os.path.join(cls.get_data_path(), local_file_name)
        if os.path.exists(filepath) and not force:
            return filepath

        try:
            response = requests.get(url, stream=True, timeout=timeout)
            try:
                # Fail before writing anything, so an HTTP error page is
                # never stored as if it were the requested data.
                response.raise_for_status()
                cls._stream_to_file(response, url, filepath, chunk_size)
            finally:
                # With stream=True the connection stays open until the body
                # is consumed or the response is closed explicitly.
                response.close()
        except BaseException:
            # Cleanup-and-reraise on any failure (including interrupts):
            # do not leave a partial, or with force=True a stale, file.
            if os.path.exists(filepath):
                os.remove(filepath)
            raise
        return filepath

    @classmethod
    def _stream_to_file(cls, response, url, filepath, chunk_size):
        """Write the streamed body of ``response`` to ``filepath``, showing progress."""
        total_size = cls._parse_content_length(response.headers.get('Content-Length'))
        logger.info(
            "Starting download of {} file with {} bytes ...".format(
                url, total_size if total_size is not None else 'unknown'))
        bar = cls._start_progress_bar(total_size)
        downloaded = 0
        with open(filepath, 'wb') as f:
            for chunk in response.iter_content(chunk_size):
                if not chunk:
                    continue  # skip empty chunks
                f.write(chunk)
                # Count real bytes: the last chunk is usually shorter than
                # chunk_size.
                downloaded += len(chunk)
                if bar is not None:
                    # Decoded (e.g. gzip) bodies can exceed Content-Length;
                    # clamp so the bar never goes past its maximum.
                    bar.update(min(downloaded, total_size))
        if bar is not None:
            bar.finish()

    @staticmethod
    def _parse_content_length(value):
        """Return ``value`` as a non-negative int, or None if missing or invalid."""
        try:
            size = int(value)
        except (TypeError, ValueError):
            return None
        return size if size >= 0 else None

    @staticmethod
    def _start_progress_bar(total_size):
        """
        Start a progress bar for a download of ``total_size`` bytes.

        Returns None when the size is unknown or zero: the percentage, bar
        and ETA widgets below are driven by a known positive maximum.
        """
        if not total_size:
            return None
        widgets = [
            'Downloading file please wait...', progressbar.Percentage(),
            ' ', progressbar.Bar(),
            ' ', progressbar.ETA(),
            ' ', progressbar.FileTransferSpeed(),
        ]
        return progressbar.ProgressBar(widgets=widgets, max_value=total_size).start()