blob: 21ebb0f96bbc646731ef0f89f60cb05d59a000de [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import pytest
import tarfile
import numpy as np
from PIL import Image
from tvm.driver import tvmc
from tvm.contrib.download import download_testdata
# Support functions
def download_and_untar(model_url, model_sub_path, temp_dir):
    """Download a model archive and return the path of a file inside it.

    The file is fetched with ``download_testdata`` (under the ``tvmc``
    module). When the downloaded path ends in ``gz`` (which includes
    ``tgz``), it is opened as a tar archive and extracted into ``temp_dir``.

    Parameters
    ----------
    model_url : str
        URL of the file to download; its basename is used as the local
        file name.
    model_sub_path : str
        Path, relative to ``temp_dir``, of the file of interest.
    temp_dir : str or path-like
        Directory that receives the extracted contents.

    Returns
    -------
    path
        ``os.path.join(temp_dir, model_sub_path)``. This is returned whether
        or not anything was extracted; existence is not checked.
    """
    model_tar_name = os.path.basename(model_url)
    model_path = download_testdata(model_url, model_tar_name, module=["tvmc"])

    # "gz" already matches "tgz"; both are listed to keep the intent explicit.
    if model_path.endswith(("tgz", "gz")):
        # The context manager closes the archive even if extraction raises,
        # so the file handle is never leaked.
        with tarfile.open(model_path) as tar:
            tar.extractall(path=temp_dir)

    return os.path.join(temp_dir, model_sub_path)
def get_sample_compiled_module(target_dir):
    """Compile the ``mobilenet_v1_1.0_224_quant`` TFLite sample model for ``llvm``.

    The model archive is downloaded and unpacked into ``target_dir``. The
    value returned is whatever ``tvmc.compiler.compile_model`` produces.
    """
    archive_url = "{}/{}".format(
        "https://storage.googleapis.com/download.tensorflow.org/models",
        "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
    )
    tflite_path = download_and_untar(
        archive_url, "mobilenet_v1_1.0_224_quant.tflite", temp_dir=target_dir
    )
    return tvmc.compiler.compile_model(tflite_path, target="llvm")
# PyTest fixtures
@pytest.fixture(scope="session")
def tflite_mobilenet_v1_1_quant(tmpdir_factory):
    """Session fixture giving the path to ``mobilenet_v1_1.0_224_quant.tflite``.

    The archive is downloaded and unpacked into a fresh ``data`` temporary
    directory created through ``tmpdir_factory``.
    """
    archive_url = "{}/{}".format(
        "https://storage.googleapis.com/download.tensorflow.org/models",
        "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
    )
    extract_dir = tmpdir_factory.mktemp("data")
    return download_and_untar(
        archive_url, "mobilenet_v1_1.0_224_quant.tflite", temp_dir=extract_dir
    )
@pytest.fixture(scope="session")
def pb_mobilenet_v1_1_quant(tmpdir_factory):
    """Session fixture giving the path to ``mobilenet_v1_1.0_224_frozen.pb``.

    The archive is downloaded and unpacked into a fresh ``data`` temporary
    directory created through ``tmpdir_factory``.

    NOTE(review): despite ``quant`` in the fixture name, the archive requested
    here is ``mobilenet_v1_1.0_224.tgz`` (no ``_quant`` suffix).
    """
    archive_url = "{}/{}".format(
        "https://storage.googleapis.com/download.tensorflow.org/models",
        "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
    )
    extract_dir = tmpdir_factory.mktemp("data")
    return download_and_untar(
        archive_url, "mobilenet_v1_1.0_224_frozen.pb", temp_dir=extract_dir
    )
@pytest.fixture(scope="session")
def keras_resnet50(tmpdir_factory):
    """Session fixture giving the path to a saved Keras ResNet50 ``.h5`` model.

    Returns an empty string when TensorFlow cannot be imported, so dependent
    tests have to cope with that value themselves.
    """
    try:
        from tensorflow.keras.applications.resnet50 import ResNet50
    except ImportError:
        # Not every environment provides TensorFlow; in that case hand back
        # an empty path instead of failing.
        return ""

    save_dir = tmpdir_factory.mktemp("data")
    h5_path = "{}/{}".format(save_dir, "resnet50.h5")

    resnet = ResNet50(
        include_top=True,
        weights="imagenet",
        input_shape=(224, 224, 3),
        classes=1000,
    )
    resnet.save(h5_path)
    return h5_path
@pytest.fixture(scope="session")
def onnx_resnet50():
    """Session fixture giving the local path of ``resnet50-v2-7.onnx``.

    The file is fetched with ``download_testdata`` under the ``tvmc`` module.
    """
    onnx_name = "resnet50-v2-7.onnx"
    onnx_url = "{}/{}".format(
        "https://github.com/onnx/models/raw/master/vision/classification/resnet/model",
        onnx_name,
    )
    return download_testdata(onnx_url, onnx_name, module=["tvmc"])
@pytest.fixture(scope="session")
def tflite_compiled_module_as_tarfile(tmpdir_factory):
    """Session fixture giving the path to ``mock.tar``, a saved compiled module.

    The sample TFLite model is compiled via ``get_sample_compiled_module`` and
    written with ``tvmc.compiler.save_module``. Returns an empty string when
    the ``tflite`` package cannot be imported.
    """
    # Some CI environments lack TFLite, and tests relying on this fixture
    # would crash there. pytest.importorskip cannot be used from inside a
    # fixture, so availability is probed by hand.
    try:
        import tflite
    except ImportError:
        print("Cannot import tflite, which is required by tflite_compiled_module_as_tarfile.")
        return ""

    out_dir = tmpdir_factory.mktemp("data")
    # The fourth element of the compile result is not needed for saving.
    graph, lib, params, _ = get_sample_compiled_module(out_dir)

    tar_path = os.path.join(out_dir, "mock.tar")
    tvmc.compiler.save_module(tar_path, graph, lib, params)
    return tar_path
@pytest.fixture(scope="session")
def imagenet_cat(tmpdir_factory):
    """Session fixture giving the path to ``imagenet_cat.npz``.

    A cat picture is downloaded, resized to 224x224, converted to float32,
    given a leading axis of length one, and stored in the archive under the
    key ``input``.
    """
    out_dir = tmpdir_factory.mktemp("data")

    png_path = download_testdata(
        "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true",
        "inputs",
        module=["tvmc"],
    )

    pixels = np.asarray(Image.open(png_path).resize((224, 224))).astype("float32")
    batch = np.expand_dims(pixels, axis=0)

    npz_path = os.path.join(out_dir, "imagenet_cat.npz")
    np.savez(npz_path, input=batch)
    return npz_path