blob: 45bd9edf3ad5254bb038f39199a71f3e7d9897a8 [file] [log] [blame]
"""
ResNet Inference Example
========================
**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_
This tutorial provides an end-to-end demo of how to run ResNet-18 inference
on the VTA accelerator design to perform ImageNet classification tasks.
"""
######################################################################
# Import Libraries
# ----------------
# We start by importing the tvm, vta, nnvm libraries to run this example.
from __future__ import absolute_import, print_function

import ast
import os
import time
from io import BytesIO

import numpy as np
import requests
from matplotlib import pyplot as plt
from PIL import Image

import tvm
from tvm import rpc, autotvm
from tvm.contrib import graph_runtime, util
from tvm.contrib.download import download
import nnvm.compiler
import vta
import vta.testing
# Load VTA parameters from the vta/config/vta_config.json file.
# The rest of this script reads `env.TARGET` ("sim" or "pynq") to pick the
# RPC session and host compiler flags, and `env.BATCH` / `env.BLOCK_IN` /
# `env.BLOCK_OUT` when packing the graph for the accelerator.
env = vta.get_env()
# Helper to center-crop an image to a square and resize it to (224, 224)
# Takes in an Image object, returns an Image object
def thumbnailify(image, pad=15):
    """Center-crop ``image`` to a square and resize it to 224x224.

    The square is as large as the shorter image side allows, is centered
    along the longer side, and is then shrunk by ``pad`` pixels on every
    edge. For landscape images this is the same box as before; portrait
    images no longer produce a crop box that extends past the image.

    Parameters
    ----------
    image : PIL.Image.Image
        Image to crop. Only its ``size``, ``crop`` and ``resize`` members
        are used.
    pad : int, optional
        Number of pixels trimmed from each edge of the centered square.

    Returns
    -------
    PIL.Image.Image
        The cropped image, resized to (224, 224).

    Raises
    ------
    ValueError
        If ``pad`` is so large that nothing is left to crop.
    """
    w, h = image.size
    # Use the shorter side so the square always fits inside the image
    side = min(w, h)
    if 2 * pad >= side:
        raise ValueError(
            "pad={} leaves no pixels to crop from a {}x{} image".format(pad, w, h))
    left = (w - side) // 2 + pad
    top = (h - side) // 2 + pad
    inner = side - 2 * pad
    cropped = image.crop((left, top, left + inner, top + inner))
    return cropped.resize((224, 224))
# Helper function to turn an image into the network input
# Takes in Image object, returns an ND array
def process_image(image):
    """Normalize ``image`` and wrap it as a batched float32 TVM array.

    Parameters
    ----------
    image : PIL.Image.Image or array-like
        Anything ``np.array`` can convert. The last axis must broadcast
        against the three per-channel constants below.

    Returns
    -------
    tvm.nd.NDArray
        The normalized data with the last axis moved to the front and a
        leading axis of length 1 added, stored as float32.
    """
    # Per-channel offset and scale applied along the last axis
    # NOTE(review): these look like the usual ImageNet RGB mean/std - confirm
    channel_offset = np.array([123., 117., 104.])
    channel_scale = np.array([58.395, 57.12, 57.375])
    normalized = (np.array(image) - channel_offset) / channel_scale
    # Move the channel axis first, then prepend the batch axis
    channel_first = normalized.transpose((2, 0, 1))
    batched = channel_first[np.newaxis, ...]
    return tvm.nd.array(batched.astype("float32"))
# Classification helper function
# Takes in the graph runtime and an image, returns "<time> <top label>"
def classify(m, image):
    """Run one timed inference on ``image`` and describe the top result.

    Relies on the module-level ``ctx`` (device context) and ``synset``
    (class index to label lookup).

    Parameters
    ----------
    m : graph runtime module
        Runtime whose ``data`` input is set and whose ``run`` function is timed.
    image : tvm.nd.NDArray
        Input passed to ``m.set_input('data', ...)``.

    Returns
    -------
    str
        Mean run time formatted as ``t=<seconds>s`` followed by a space and
        the label of the highest-scoring class.
    """
    m.set_input('data', image)
    # time_evaluator both executes "run" and reports how long it took
    run_timer = m.module.time_evaluator("run", ctx, number=1)
    timing = run_timer()
    scores = m.get_output(0).asnumpy()[0]
    best = np.argmax(scores)
    elapsed = "t={0:.2f}s".format(timing.mean)
    label = " {}".format(synset[best])
    return elapsed + label
# Helper function to compile the NNVM graph
# Takes in a path to a graph file, params file, and device target
# Returns the graph object, the library loaded on the remote, and the params dict
def generate_graph(graph_fn, params_fn, device="vta"):
    """Compile the NNVM graph stored on disk and load it on the RPC remote.

    Relies on the module-level ``env`` (VTA configuration) and ``remote``
    (RPC session), which must be set up before this is called.

    Parameters
    ----------
    graph_fn : str
        Path to the JSON file holding the NNVM graph.
    params_fn : str
        Path to the binary file holding the serialized parameter dict.
    device : str, optional
        Device name inserted into the ``llvm -device=...`` target string.
        The graph is packed and built under ``vta.build_config()`` only
        when the resulting target's device name is ``"vta"``.

    Returns
    -------
    graph
        The compiled graph returned by ``nnvm.compiler.build``.
    lib
        The compiled library, uploaded to and loaded from ``remote``.
    params : dict
        The parameter dict returned by ``nnvm.compiler.build``.

    Raises
    ------
    ValueError
        If ``env.TARGET`` is neither ``"sim"`` nor ``"pynq"``.
    """
    # Measure build start time
    build_start = time.time()

    # Derive the TVM target
    target = tvm.target.create("llvm -device={}".format(device))

    # Derive the LLVM compiler flags
    # When targeting the Pynq, cross-compile to ARMv7 ISA
    if env.TARGET == "sim":
        target_host = "llvm"
    elif env.TARGET == "pynq":
        target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
    else:
        # Fail here with a clear message instead of an unbound-variable
        # error at build time
        raise ValueError(
            "Unsupported VTA target {!r}: expected 'sim' or 'pynq'".format(env.TARGET))

    # Load the ResNet-18 graph and parameters (closing the files afterwards)
    with open(graph_fn) as graph_file:
        sym = nnvm.graph.load_json(graph_file.read())
    with open(params_fn, 'rb') as params_file:
        params = nnvm.compiler.load_param_dict(params_file.read())

    # Populate the shape and data type dictionary
    shape_dict = {"data": (1, 3, 224, 224)}
    dtype_dict = {"data": 'float32'}
    shape_dict.update({k: v.shape for k, v in params.items()})
    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})

    # Apply NNVM graph optimization passes
    sym = vta.graph.clean_cast(sym)
    sym = vta.graph.clean_conv_fuse(sym)
    if target.device_name == "vta":
        assert env.BLOCK_IN == env.BLOCK_OUT
        sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)

    # Compile NNVM graph
    with nnvm.compiler.build_config(opt_level=3):
        if target.device_name != "vta":
            graph, lib, params = nnvm.compiler.build(
                sym, target, shape_dict, dtype_dict,
                params=params, target_host=target_host)
        else:
            with vta.build_config():
                graph, lib, params = nnvm.compiler.build(
                    sym, target, shape_dict, dtype_dict,
                    params=params, target_host=target_host)

    # Save the compiled inference graph library
    assert tvm.module.enabled("rpc")
    temp = util.tempdir()
    lib_path = temp.relpath("graphlib.o")
    lib.save(lib_path)

    # Send the inference library over to the remote RPC server
    remote.upload(lib_path)
    lib = remote.load_module("graphlib.o")

    # Measure build time
    build_time = time.time() - build_start
    print("ResNet-18 inference graph built in {0:.2f}s!".format(build_time))

    return graph, lib, params
######################################################################
# Download ResNet Model
# --------------------------------------------
# Download the necessary files to run ResNet-18.
#
# Obtain the ResNet model files and download them into the _data dir
url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
categ_fn = 'synset.txt'
graph_fn = 'resnet18_qt8.json'
params_fn = 'resnet18_qt8.params'

# Create data dir
data_dir = "_data/"
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

# Download files
# `url` already ends with "/", so plain concatenation yields the file URL
# (os.path.join is meant for filesystem paths, not URLs)
for fname in [categ_fn, graph_fn, params_fn]:
    download(url + fname, os.path.join(data_dir, fname))

# Read in ImageNet Categories
# The file was just downloaded, so parse it as a literal instead of running
# it through eval(), which would execute arbitrary code.
# NOTE(review): assumes synset.txt holds a plain Python literal (it is
# indexed by class id further down) - confirm against the hosted file.
with open(os.path.join(data_dir, categ_fn)) as synset_file:
    synset = ast.literal_eval(synset_file.read())

# Download pre-tuned op parameters of conv2d for ARM CPU used in VTA
autotvm.tophub.check_backend('vta')
######################################################################
# Setup the Pynq Board's RPC Server
# ---------------------------------
# Build the RPC server's VTA runtime and program the Pynq FPGA.
# Measure reconfiguration start time
reconfig_start = time.time()

# We read the Pynq RPC host IP address and port number from the OS environment
host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))

# We configure both the bitstream and the runtime system on the Pynq
# to match the VTA configuration specified by the vta_config.json file.
if env.TARGET == "pynq":
    # Make sure that TVM was compiled with RPC=1
    assert tvm.module.enabled("rpc")
    remote = rpc.connect(host, port)

    # Reconfigure the JIT runtime
    vta.reconfig_runtime(remote)

    # Program the FPGA with a pre-compiled VTA bitstream.
    # You can program the FPGA with your own custom bitstream
    # by passing the path to the bitstream file instead of None.
    vta.program_fpga(remote, bitstream=None)

    # Report on reconfiguration time
    reconfig_time = time.time() - reconfig_start
    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))

# In simulation mode, host the RPC server locally.
elif env.TARGET == "sim":
    remote = rpc.LocalSession()

# Any other target would leave `remote` undefined and fail much later with
# a NameError, so stop here with an explicit message.
else:
    raise ValueError(
        "Unsupported VTA target {!r}: expected 'sim' or 'pynq'".format(env.TARGET))
######################################################################
# Build the ResNet Runtime
# ------------------------
# Build the ResNet graph runtime, and configure the parameters.
# Set ``device=vtacpu`` to run inference on the CPU
# or ``device=vta`` to run inference on the FPGA.
device = "vta"

# Device context: the remote's ext_dev context when running on "vta",
# otherwise the remote's CPU context
if device == "vta":
    ctx = remote.ext_dev(0)
else:
    ctx = remote.cpu(0)

# Build the graph runtime from the downloaded graph and parameter files
graph, lib, params = generate_graph(
    os.path.join(data_dir, graph_fn),
    os.path.join(data_dir, params_fn),
    device)
m = graph_runtime.create(graph, lib, ctx)

# Set the parameters
m.set_input(**params)
######################################################################
# Run ResNet-18 inference on a sample image
# -----------------------------------------
# Perform image classification on test image.
# You can change the test image URL to any image of your choosing.
# URL of the test image
image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'

# Download the test image; fail with the HTTP error right away rather than
# handing an error page to the image decoder
response = requests.get(image_url)
response.raise_for_status()

# Force 3-channel RGB so that grayscale, palette or RGBA images also match
# the per-channel normalization done in process_image
image = Image.open(BytesIO(response.content)).convert("RGB").resize((224, 224))

# Show Image
plt.imshow(image)
plt.show()

# Set the input
image = process_image(image)
m.set_input('data', image)

# Perform inference (time_evaluator runs the graph and measures it)
timer = m.module.time_evaluator("run", ctx, number=1)
tcost = timer()

# Get classification results; argsort is ascending, so the best classes
# sit at the end of `top_categories`
tvm_output = m.get_output(0)
top_categories = np.argsort(tvm_output.asnumpy()[0])

# Report top-5 classification results
print("ResNet-18 Prediction #1:", synset[top_categories[-1]])
print(" #2:", synset[top_categories[-2]])
print(" #3:", synset[top_categories[-3]])
print(" #4:", synset[top_categories[-4]])
print(" #5:", synset[top_categories[-5]])
print("Performed inference in {0:.2f}s".format(tcost.mean))
######################################################################
# Run a Youtube Video Image Classifier
# ------------------------------------
# Perform image classification on a video stream, processing 1 frame out of every 48.
# Change the ``if False:`` below to ``if True:`` to run the demo.
# The demo is disabled by default.
if False:
    import cv2
    import pafy
    from IPython.display import clear_output

    # Frames are cropped with the module-level thumbnailify() helper, so it
    # is not redefined here.

    # 16:16 inches
    plt.rcParams['figure.figsize'] = [16, 16]

    # Stream the video in (kept in its own name so the model download `url`
    # defined earlier is not overwritten)
    video_url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s"
    video = pafy.new(video_url)
    best = video.getbest(preftype="mp4")
    cap = cv2.VideoCapture(best.url)

    try:
        # Process one frame out of every 48 for variety
        count = 0
        guess = ""
        while count < 2400:

            # Capture frame-by-frame
            ret, frame = cap.read()
            # Stop when the stream ends or a frame cannot be read; `frame`
            # is not usable in that case
            if not ret:
                break

            # Process one every 48 frames
            if count % 48 == 1:

                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = Image.fromarray(frame)
                # Crop and resize
                thumb = np.array(thumbnailify(frame))
                image = process_image(thumb)
                guess = classify(m, image)

                # Insert guess in frame (drawn on a black banner at the top)
                frame = cv2.rectangle(thumb, (0, 0), (200, 0), (0, 0, 0), 50)
                # White text: 255 is the maximum 8-bit channel value
                cv2.putText(frame, guess, (5, 15), cv2.FONT_HERSHEY_SIMPLEX,
                            0.5, (255, 255, 255), 1, cv2.LINE_AA)

                plt.imshow(thumb)
                plt.axis('off')
                plt.show()
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
                clear_output(wait=True)

            count += 1
    finally:
        # Release the capture even if a frame fails to process
        cap.release()
        cv2.destroyAllWindows()