# blob: 25020467e58dfcc05ecbc913aae76ef6d020a5c7 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Python VTA Deploy."""
from __future__ import absolute_import, print_function
import os
from os.path import join
from io import BytesIO
from PIL import Image
import requests
import numpy as np
import tvm
from tvm.contrib import graph_runtime, download
# Device context (TVM "ext_dev" device, index 0) shared by the graph runtime
# created in load_model() and by the output buffer allocated in the script
# entry point.  Presumably this is the VTA accelerator -- confirm against the
# TVM build configuration.
CTX = tvm.ext_dev(0)
def load_vta_library():
    """Load the VTA shared library into the TVM runtime.

    The library is expected at ``build/libvta.so`` under the directory four
    levels above the directory that contains this file.

    Returns
    -------
    The module object returned by ``tvm.runtime.load_module``.
    """
    # Resolve this file's own absolute location first, so the lookup does not
    # depend on the current working directory.
    this_file = os.path.abspath(os.path.expanduser(__file__))
    script_dir = os.path.dirname(this_file)

    # Walk up four levels to the project root, then down to the built library.
    root_dir = os.path.abspath(os.path.join(script_dir, "../../../../"))
    lib_path = os.path.abspath(os.path.join(root_dir, "build/libvta.so"))

    return tvm.runtime.load_module(lib_path)
def load_model():
    """Load the compiled VTA model, its parameters and the category labels.

    The graph description, compiled library and parameter blob are read from
    ``./build/model/`` relative to the current working directory.  The
    category file ``synset.txt`` is downloaded into the current working
    directory and then evaluated as a Python expression.

    Returns
    -------
    model
        Graph runtime module created from the graph and library on ``CTX``.
    param_bytes : bytes
        Raw contents of ``params.params``; the caller passes them to
        ``load_params``.
    synset
        Object obtained by evaluating ``synset.txt``; the caller indexes it
        by category id.
    """
    # The VTA library is loaded before the model library; its return value is
    # not needed here.  Presumably the load registers the VTA runtime symbols
    # the model library relies on -- confirm against the TVM/VTA docs.
    load_vta_library()

    with open("./build/model/graph.json", "r") as graphfile:
        graph = graphfile.read()
    lib = tvm.runtime.load_module("./build/model/lib.so")
    model = graph_runtime.create(graph, lib, CTX)

    with open("./build/model/params.params", "rb") as paramfile:
        param_bytes = paramfile.read()

    categ_url = "https://github.com/uwsampl/web-data/raw/main/vta/models/"
    categ_fn = "synset.txt"
    # A URL is not a filesystem path, so build it by plain concatenation
    # (the base already ends with "/") rather than with os.path.join.
    download.download(categ_url + categ_fn, categ_fn)

    # Read through a context manager so the file handle is closed promptly.
    with open(categ_fn) as synsetfile:
        synset_text = synsetfile.read()

    # SECURITY NOTE(review): this evaluates content fetched from the network
    # as arbitrary Python.  If synset.txt only ever holds a literal (dict or
    # list), ast.literal_eval would be a safe drop-in -- confirm the file
    # format before switching.
    synset = eval(synset_text)  # pylint: disable=eval-used
    return model, param_bytes, synset
if __name__ == "__main__":
    # Script entry point: classify one test image and print the top category.
    MOD, PARAMS_BYTES, SYNSET = load_model()

    IMAGE_URL = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
    # Bound the wait and fail on an HTTP error status, so a bad download is
    # reported as such instead of as an image-decoding error further down.
    RESPONSE = requests.get(IMAGE_URL, timeout=30)
    RESPONSE.raise_for_status()

    # Prepare test image for inference
    IMAGE = Image.open(BytesIO(RESPONSE.content)).resize((224, 224))

    # NOTE(review): the PIL image is handed to set_input without an explicit
    # conversion to an array; confirm its layout and dtype match what the
    # model's 'data' input expects.
    MOD.set_input('data', IMAGE)
    MOD.load_params(PARAMS_BYTES)
    MOD.run()

    # Fetch output 0 into a freshly allocated (1, 1000) float32 buffer on CTX.
    TVM_OUTPUT = MOD.get_output(0, tvm.nd.empty((1, 1000), "float32", CTX))

    # argsort orders ascending, so the last index is the highest-scoring one.
    TOP_CATEGORIES = np.argsort(TVM_OUTPUT.asnumpy()[0])
    print("\t#1:", SYNSET[TOP_CATEGORIES[-1]])