# blob: eef9e0b06349fa34b5a429a72c356b1aff7f88ef [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import tvm
from tvm import te
import numpy as np
from tvm import rpc
from tvm.contrib import util, tflite_runtime
# import tflite_runtime.interpreter as tflite
def skipped_test_tflite_runtime():
    """Check TVM's remote TFLite runtime against the TFLite Python interpreter.

    The same quantized MobileNet model is executed twice on one random uint8
    input: once through the TFLite interpreter Python API and once through
    ``tvm.contrib.tflite_runtime`` over an RPC session to a local server.
    The two outputs must be exactly equal. The check runs first with the
    plain model and then with the EdgeTPU-compiled model.

    Model files and the EdgeTPU delegate library are looked up under the
    directory named by the ``EDGETPU_PATH`` environment variable
    (default ``/home/mendel/edgetpu``).

    Raises
    ------
    AssertionError
        If the TVM runtime output differs from the TFLite interpreter output.
    """
    # Imported at function scope: the module-level import is commented out,
    # so ``tflite`` would otherwise be unbound here. Keeping the import local
    # lets this module be imported when the package is not installed.
    import tflite_runtime.interpreter as tflite

    # Root of the edgetpu checkout; read once and shared by the helpers below.
    edgetpu_path = os.getenv("EDGETPU_PATH", "/home/mendel/edgetpu")

    def get_tflite_model_path(target_edgetpu):
        """Return the path of the MobileNet model for the requested target."""
        if target_edgetpu:
            model_name = "test_data/mobilenet_v1_1.0_224_quant_edgetpu.tflite"
        else:
            model_name = "test_data/mobilenet_v1_1.0_224_quant.tflite"
        return os.path.join(edgetpu_path, model_name)

    def init_interpreter(model_path, target_edgetpu):
        """Create a TFLite interpreter, with the EdgeTPU delegate if requested."""
        if not target_edgetpu:
            return tflite.Interpreter(model_path=model_path)
        libedgetpu = os.path.join(edgetpu_path, "libedgetpu/direct/aarch64/libedgetpu.so.1")
        return tflite.Interpreter(
            model_path=model_path,
            experimental_delegates=[tflite.load_delegate(libedgetpu)],
        )

    def check_remote(target_edgetpu=False):
        """Run the model via TFLite and via remote TVM, then compare outputs."""
        tflite_model_path = get_tflite_model_path(target_edgetpu)

        # Reference inference via the tflite interpreter python apis.
        interpreter = init_interpreter(tflite_model_path, target_edgetpu)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        input_shape = input_details[0]["shape"]
        # Draw real uint8 values. Casting random_sample() floats from [0, 1)
        # to uint8 truncates every element to 0, which would make the input
        # constant and the comparison far less meaningful.
        tflite_input = np.random.randint(0, 256, size=tuple(input_shape), dtype=np.uint8)
        interpreter.set_tensor(input_details[0]["index"], tflite_input)
        interpreter.invoke()
        tflite_output = interpreter.get_tensor(output_details[0]["index"])

        # Inference via the remote tvm tflite runtime.
        server = rpc.Server("localhost")
        remote = rpc.connect(server.host, server.port)
        ctx = remote.cpu(0)

        with open(tflite_model_path, "rb") as model_fin:
            runtime = tflite_runtime.create(model_fin.read(), ctx)
        runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
        runtime.invoke()
        out = runtime.get_output(0)
        np.testing.assert_equal(out.asnumpy(), tflite_output)

    # Target CPU on coral board
    check_remote()
    # Target EdgeTPU on coral board
    check_remote(target_edgetpu=True)
if __name__ == "__main__":
    # The test is left disabled; to run it by hand, call
    # skipped_test_tflite_runtime() here instead of the no-op below.
    pass