blob: 03478dc41e33c1104511184a6682097d75b7cc13 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name
# pylint: disable=arguments-differ, unused-argument
"""Unit tests for various models and operators"""
import os
import numpy as np
import oneflow as flow
from flowvision.models.alexnet import alexnet
from flowvision.models.squeezenet import squeezenet1_0
from flowvision.models.shufflenet_v2 import shufflenet_v2_x0_5
from flowvision.models.mobilenet import mobilenet_v2
from flowvision.models.ghostnet import ghostnet
from flowvision.models.vision_transformer import vit_base_patch16_224
import tvm
import tvm.testing
import tvm.topi.testing
from tvm import relay
MODEL_HOME = "test_model"
def mkdir(path):
# init
path = path.strip()
path = path.rstrip("\\")
if not os.path.exists(path):
os.makedirs(path)
else:
print(f"{path} is already here")
def rmdir(path):
for root, dirs, files in os.walk(path, topdown=False):
for name in files:
os.remove(os.path.join(root, name))
for name in dirs:
os.rmdir(os.path.join(root, name))
os.removedirs(path)
def assert_shape(out1, out2):
if out1.shape != out2.shape:
msg = "Output shapes {} and {} don't match"
raise AssertionError(msg.format(out1.shape, out2.shape))
class OneFlowGraph(flow.nn.Graph):
def __init__(self, module):
super().__init__()
self.m = module
def build(self, x):
out = self.m(x)
return out
def get_oneflow_output(model, inputs):
flow_output = model(inputs)
return flow_output.numpy()
def get_tvm_output(graph, model_path, inputs: flow.tensor, target="llvm", dtype="float32"):
"""Generic function to execute and get tvm output"""
inputs_numpy = inputs.numpy()
if target == "llvm":
device = tvm.cpu(0)
elif target == "cuda":
device = tvm.cuda(0)
mod, params = relay.frontend.from_oneflow(graph, model_path)
with tvm.transform.PassContext(opt_level=10):
intrp = relay.build_module.create_executor("graph", mod, device, target)
tvm_output = intrp.evaluate()(tvm.nd.array(inputs_numpy.astype(dtype)), **params).numpy()
return tvm_output
def verify_model(
model,
name="",
rtol=1e-5,
atol=1e-5,
inputs=flow.tensor(
np.random.rand(1, 3, 224, 224),
dtype=flow.float32,
),
device="llvm",
):
"""Generic function to generate and compare oneflow and TVM output"""
if device == "cuda":
model.to(device)
inputs = inputs.to(device)
graph = OneFlowGraph(model)
graph._compile(inputs)
mkdir(MODEL_HOME)
flow.save(model.state_dict(), MODEL_HOME)
out_flow = get_oneflow_output(graph, inputs)
out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device)
rmdir(MODEL_HOME)
assert_shape(out_flow, out_tvm)
tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol)
@tvm.testing.uses_gpu
def test_vision_models():
"""Vision models test"""
if os.path.exists(MODEL_HOME):
rmdir(MODEL_HOME)
vision_alexnet = alexnet().eval()
vision_squeezenet = squeezenet1_0().eval()
vision_shufflenet = shufflenet_v2_x0_5().eval()
vision_mobilenetv2 = mobilenet_v2().eval()
vision_ghostnet = ghostnet().eval()
vision_vit = vit_base_patch16_224().eval()
for device in ["llvm"]:
verify_model(vision_alexnet, device=device)
verify_model(vision_squeezenet, device=device)
verify_model(vision_shufflenet, device=device)
verify_model(vision_mobilenetv2, device=device)
verify_model(vision_ghostnet, device=device)
verify_model(vision_vit, device=device)
if __name__ == "__main__":
test_vision_models()
rmdir("log")