blob: 5586e5ba688b94d8ad0986ce7de1436f1267b321 [file] [log] [blame]
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from PIL import Image
from tvm.contrib import graph_executor
# Artifacts of the compiled model (shared library, serialized parameters,
# graph description) and the sample image that is fed through it.
lib_path = "./outlib/graph.o.so"
param_path = "./outlib/graph.params"
json_path = "./outlib/graph.json"
img_path = "./data/img_10.jpg"

loaded_lib = tvm.runtime.load_module(lib_path)
print(loaded_lib)
dev = tvm.runtime.cpu()

# Read both artifacts inside the context manager so the file handles are
# released before the executor is built.
with open(json_path) as jsonfile, open(param_path, "rb") as paramfile:
    graph_json = jsonfile.read()
    loaded_param = bytearray(paramfile.read())

module = graph_executor.create(graph_json, loaded_lib, dev)
module.load_params(loaded_param)

# Load the sample as single-channel grayscale and resize it to 28x28.
# convert("L") keeps the element count at 28 * 28 so the reshape below also
# works when the JPEG decodes as RGB; it leaves grayscale input unchanged.
with Image.open(img_path) as source_image:
    resized_image = source_image.convert("L").resize((28, 28))

# Scale the 8-bit pixel values into [0, 1] and lay them out as a
# (1, 1, 28, 28) float32 array for the graph input.
img_data = np.asarray(resized_image).astype("float32") / 255
img_data = np.reshape(img_data, (1, 1, 28, 28))

# "Input3" is the name of the graph input this script feeds.
module.set_input("Input3", img_data)
module.run()

# Copy output 0 into a freshly allocated (1, 10) buffer and print it as a
# NumPy array.
output_shape = (1, 10)
tvm_output = module.get_output(0, tvm.nd.empty(output_shape)).numpy()
print(tvm_output)