# blob: 24efd68ca4ca426d40b873717d7f11a22add7ae5 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import io
import os
import warnings
import glob
import shutil
import multiprocessing as mp
try:
from unittest import mock
except ImportError:
import mock
import mxnet as mx
import requests
import pytest
class MockResponse(requests.Response):
    """ canned requests.Response carrying a fixed status code and text body

    Parameters
    ----------
    status_code : int
        Status code stored on the response.
    content : str
        Text that is UTF-8 encoded and exposed through the ``raw`` stream.
    """

    def __init__(self, status_code, content):
        super(MockResponse, self).__init__()
        assert isinstance(status_code, int)
        self.status_code = status_code
        # back the response with an in-memory byte stream instead of a socket
        payload = content.encode('utf-8')
        self.raw = io.BytesIO(payload)
@mock.patch(
    'requests.get',
    mock.Mock(side_effect=requests.exceptions.ConnectionError))
def test_download_retries_error():
    """ test download raises when every request fails with ConnectionError """
    unreachable_url = "http://doesnotexist.notfound"
    with pytest.raises(Exception):
        mx.gluon.utils.download(unreachable_url)
@mock.patch(
    'requests.get',
    mock.Mock(
        side_effect=lambda *_args, **_kwargs: MockResponse(
            200, 'MOCK CONTENT' * 100)))
def _download_successful(tmp):
    """ helper: download a fixed URL to ``tmp`` with requests.get mocked

    The mocked ``requests.get`` answers every call with a status-200
    MockResponse whose body is 'MOCK CONTENT' repeated 100 times.
    """
    readme_url = (
        "https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md")
    mx.gluon.utils.download(readme_url, path=tmp)
def test_download_successful(tmpdir):
    """ test a single in-process download leaves exactly one README.md file """
    workdir = str(tmpdir)
    target_path = os.path.join(workdir, 'README.md')

    _download_successful(target_path)

    # the mocked body is well over 100 bytes
    downloaded_size = os.path.getsize(target_path)
    assert downloaded_size > 100, downloaded_size

    # nothing but the requested file may remain next to it
    leftovers = glob.glob(os.path.join(workdir, 'README.md*'))
    assert len(leftovers) == 1, leftovers

    # remove the temporary directory
    shutil.rmtree(workdir)
def test_multiprocessing_download_successful(tmpdir):
    """ test ten concurrent processes downloading to the same target path """
    workdir = str(tmpdir)
    target_path = os.path.join(workdir, 'README.md')

    # launch 10 worker processes, each one starting as soon as it is created
    workers = []
    for _ in range(10):
        worker = mp.Process(target=_download_successful, args=(target_path,))
        workers.append(worker)
        worker.start()

    # wait for every worker before inspecting the result
    for worker in workers:
        worker.join()

    downloaded_size = os.path.getsize(target_path)
    assert downloaded_size > 100, downloaded_size

    # nothing but the requested file may remain next to it
    leftovers = glob.glob(os.path.join(workdir, 'README.md*'))
    assert len(leftovers) == 1, leftovers

    # remove the temporary directory
    shutil.rmtree(workdir)
@mock.patch(
    'requests.get',
    mock.Mock(
        side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT')))
def test_download_ssl_verify():
    """ test download verify_ssl parameter

    Calls download with ``verify_ssl=False`` while ``requests.get`` is mocked
    and checks that a warning whose message starts with
    'Unverified HTTPS request' was recorded.
    """
    with warnings.catch_warnings(record=True) as warnings_:
        # Record every warning regardless of filters inherited from the
        # environment (e.g. -W ignore / PYTHONWARNINGS); without this an
        # ambient filter can suppress the warning and fail the test spuriously.
        warnings.simplefilter('always')
        mx.gluon.utils.download(
            "https://mxnet.apache.org/index.html", verify_ssl=False)
    assert any(
        str(w.message).startswith('Unverified HTTPS request')
        for w in warnings_)