.. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY
.. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE
.. CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "topic/vta/tutorials/frontend/deploy_detection.py"
.. only:: html
.. note::
:class: sphx-glr-download-link-note
This tutorial can be used interactively with Google Colab! You can also click
:ref:`here <sphx_glr_download_topic_vta_tutorials_frontend_deploy_detection.py>` to run the Jupyter notebook locally.
.. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg
:align: center
:target: https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/66e1a42229aae7ed49ac268f520e6727/deploy_detection.ipynb
:width: 300px
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_topic_vta_tutorials_frontend_deploy_detection.py:
Deploy Pretrained Vision Detection Model from Darknet on VTA
============================================================
**Author**: `Hua Jiang <https://github.com/huajsj>`_
This tutorial provides an end-to-end demo of running Darknet YOLOv3-tiny
inference on the VTA accelerator design to perform image detection.
It showcases Relay as a front end compiler that can perform quantization (VTA
only supports int8/32 inference) as well as graph packing (to enable
tensorization in the core) to massage the compute graph for the hardware target.
.. GENERATED FROM PYTHON SOURCE LINES 30-48
Install dependencies
--------------------
To use the autotvm package in tvm, we need to install some extra dependencies
(change "3" to "2" if you use Python 2):
.. code-block:: bash
pip3 install "Pillow<7"
Parsing the YOLOv3-tiny model with Darknet depends on the CFFI and OpenCV (cv2)
libraries, so we need to install both before executing this script.
.. code-block:: bash
pip3 install cffi
pip3 install opencv-python
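Optionally, you can verify from Python that these dependencies are importable; a minimal check (note that the import names differ from the pip package names):
.. code-block:: python
# Quick sanity check: import names are ``cffi``, ``cv2`` (opencv-python), and ``PIL`` (Pillow).
import cffi
import cv2
import PIL
print(cffi.__version__, cv2.__version__, PIL.__version__)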
Now return to the Python code and import the packages.
.. GENERATED FROM PYTHON SOURCE LINES 48-69
.. code-block:: default
from __future__ import absolute_import, print_function
import sys
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import tvm
import vta
from tvm import rpc, autotvm, relay
from tvm.relay.testing import yolo_detection, darknet
from tvm.relay.testing.darknet import __darknetffi__
from tvm.contrib import graph_executor, utils
from tvm.contrib.download import download_testdata
from vta.testing import simulator
from vta.top import graph_pack
# Make sure that TVM was compiled with RPC=1
assert tvm.runtime.enabled("rpc")
.. GENERATED FROM PYTHON SOURCE LINES 70-73
Download the YOLO network configuration, weights, and Darknet library files based on the model name
----------------------------------------------------------------------------------------------------
.. GENERATED FROM PYTHON SOURCE LINES 73-100
.. code-block:: default
MODEL_NAME = "yolov3-tiny"
REPO_URL = "https://github.com/dmlc/web-data/blob/main/darknet/"
cfg_path = download_testdata(
    "https://github.com/pjreddie/darknet/blob/master/cfg/" + MODEL_NAME + ".cfg" + "?raw=true",
    MODEL_NAME + ".cfg",
    module="darknet",
)
weights_path = download_testdata(
    "https://pjreddie.com/media/files/" + MODEL_NAME + ".weights" + "?raw=true",
    MODEL_NAME + ".weights",
    module="darknet",
)
if sys.platform in ["linux", "linux2"]:
    darknet_lib_path = download_testdata(
        REPO_URL + "lib/" + "libdarknet2.0.so" + "?raw=true", "libdarknet2.0.so", module="darknet"
    )
elif sys.platform == "darwin":
    darknet_lib_path = download_testdata(
        REPO_URL + "lib_osx/" + "libdarknet_mac2.0.so" + "?raw=true",
        "libdarknet_mac2.0.so",
        module="darknet",
    )
else:
    raise NotImplementedError("Darknet lib is not supported on {} platform".format(sys.platform))
.. GENERATED FROM PYTHON SOURCE LINES 101-103
Download the YOLO categories and illustration font.
----------------------------------------------------
.. GENERATED FROM PYTHON SOURCE LINES 103-113
.. code-block:: default
coco_path = download_testdata(
    REPO_URL + "data/" + "coco.names" + "?raw=true", "coco.names", module="data"
)
font_path = download_testdata(
    REPO_URL + "data/" + "arial.ttf" + "?raw=true", "arial.ttf", module="data"
)
with open(coco_path) as f:
    content = f.readlines()
names = [x.strip() for x in content]
.. GENERATED FROM PYTHON SOURCE LINES 114-117
Define the platform and model targets.
--------------------------------------
Execute on CPU vs. VTA, and define the model.
.. GENERATED FROM PYTHON SOURCE LINES 117-141
.. code-block:: default
# Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file
env = vta.get_env()
# Set ``device=arm_cpu`` to run inference on the CPU
# or ``device=vta`` to run inference on the FPGA.
device = "vta"
target = env.target if device == "vta" else env.target_vta_cpu
pack_dict = {
    "yolov3-tiny": ["nn.max_pool2d", "cast", 4, 186],
}
# Name of the Darknet model to compile
# The ``start_pack`` and ``stop_pack`` labels indicate where
# to start and end the graph packing relay pass: in other words
# where to start and finish offloading to VTA.
# The number 4 indicates that the ``start_pack`` index is 4, and the
# number 186 indicates that the ``stop_pack`` index is 186. Combining the
# operator name with the index number lets us locate the correct start/end
# points even when the graph contains multiple ``nn.max_pool2d`` or
# ``cast`` operators; print(mod.astext(show_meta_data=False)) can help
# to find the operator names and index information.
assert MODEL_NAME in pack_dict
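To derive these indices for a different model, one rough approach is to print the module and count occurrences of the boundary operators. A minimal sketch, assuming ``mod`` holds the quantized Relay module built later in this tutorial:
.. code-block:: python
# Dump the Relay module text and count occurrences of a boundary operator.
# Note: naive substring counting may over-count (e.g. ``broadcast_to`` contains
# ``cast``), so treat these numbers as a starting point, not the final index.
text = mod.astext(show_meta_data=False)
print(text.count("nn.max_pool2d"), "occurrences of nn.max_pool2d")
print(text.count("cast"), "occurrences of cast")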
.. GENERATED FROM PYTHON SOURCE LINES 142-146
Obtain an execution remote.
---------------------------
When the target is 'pynq' or another FPGA backend, reconfigure the FPGA and the runtime.
Otherwise, if the target is 'sim', execute locally.
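If you have a tracker running, a minimal sketch of pointing this script at it through the environment (the host and port below are placeholders, not real defaults):
.. code-block:: python
import os
# Hypothetical tracker address; replace with the host/port of your own tracker.
os.environ["TVM_TRACKER_HOST"] = "0.0.0.0"
os.environ["TVM_TRACKER_PORT"] = "9190"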
.. GENERATED FROM PYTHON SOURCE LINES 146-180
.. code-block:: default
if env.TARGET not in ["sim", "tsim"]:
# Get remote from tracker node if environment variable is set.
# To set up the tracker, you'll need to follow the "Auto-tuning
# a convolutional network for VTA" tutorial.
tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
# Otherwise if you have a device you want to program directly from
# the host, make sure you've set the variables below to the IP of
# your board.
device_host = os.environ.get("VTA_RPC_HOST", "192.168.2.99")
device_port = os.environ.get("VTA_RPC_PORT", "9091")
if not tracker_host or not tracker_port:
remote = rpc.connect(device_host, int(device_port))
else:
remote = autotvm.measure.request_remote(
env.TARGET, tracker_host, int(tracker_port), timeout=10000
)
# Reconfigure the JIT runtime and FPGA.
# You can program the FPGA with your own custom bitstream
# by passing the path to the bitstream file instead of None.
reconfig_start = time.time()
vta.reconfig_runtime(remote)
vta.program_fpga(remote, bitstream=None)
reconfig_time = time.time() - reconfig_start
print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
# In simulation mode, host the RPC server locally.
else:
remote = rpc.LocalSession()
# Get execution context from remote
ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
.. GENERATED FROM PYTHON SOURCE LINES 181-194
Build the inference graph executor.
-----------------------------------
Use the Darknet library to load the downloaded vision model and compile it with Relay.
The compilation steps are:
1. Front end translation from Darknet into a Relay module.
2. Apply 8-bit quantization: here we skip the first conv layer
and the dense layer, which will both be executed in fp32 on the CPU.
3. Perform graph packing to alter the data layout for tensorization.
4. Perform constant folding to reduce the number of operators (e.g. eliminate batch norm multiply).
5. Perform a Relay build to an object file.
6. Load the object file onto the remote (FPGA device).
7. Generate the graph executor, `m`.
.. GENERATED FROM PYTHON SOURCE LINES 194-253
.. code-block:: default
# Load pre-configured AutoTVM schedules
with autotvm.tophub.context(target):
    net = __darknetffi__.dlopen(darknet_lib_path).load_network(
        cfg_path.encode("utf-8"), weights_path.encode("utf-8"), 0
    )
    dshape = (env.BATCH, net.c, net.h, net.w)
    dtype = "float32"
    # Measure build start time
    build_start = time.time()
    # Start front end compilation
    mod, params = relay.frontend.from_darknet(net, dtype=dtype, shape=dshape)
    if target.device_name == "vta":
        # Perform quantization in Relay
        # Note: We set opt_level to 3 in order to fold batch norm
        with tvm.transform.PassContext(opt_level=3):
            with relay.quantize.qconfig(
                global_scale=23.0,
                skip_conv_layers=[0],
                store_lowbit_output=True,
                round_for_shift=True,
            ):
                mod = relay.quantize.quantize(mod, params=params)
            # Perform graph packing and constant folding for VTA target
            mod = graph_pack(
                mod["main"],
                env.BATCH,
                env.BLOCK_OUT,
                env.WGT_WIDTH,
                start_name=pack_dict[MODEL_NAME][0],
                stop_name=pack_dict[MODEL_NAME][1],
                start_name_idx=pack_dict[MODEL_NAME][2],
                stop_name_idx=pack_dict[MODEL_NAME][3],
            )
    else:
        mod = mod["main"]
    # Compile Relay program with AlterOpLayout disabled
    with vta.build_config(disabled_pass={"AlterOpLayout", "tir.CommonSubexprElimTIR"}):
        lib = relay.build(
            mod, target=tvm.target.Target(target, host=env.target_host), params=params
        )
    # Measure Relay build time
    build_time = time.time() - build_start
    print(MODEL_NAME + " inference graph built in {0:.2f}s!".format(build_time))
    # Send the inference library over to the remote RPC server
    temp = utils.tempdir()
    lib.export_library(temp.relpath("graphlib.tar"))
    remote.upload(temp.relpath("graphlib.tar"))
    lib = remote.load_module("graphlib.tar")
    # Graph executor
    m = graph_executor.GraphModule(lib["default"](ctx))
.. rst-class:: sphx-glr-script-out
.. code-block:: none
/workspace/python/tvm/relay/build_module.py:345: DeprecationWarning: Please use input parameter mod (tvm.IRModule) instead of deprecated parameter mod (tvm.relay.function.Function)
warnings.warn(
yolov3-tiny inference graph built in 30.60s!
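Before setting inputs, you can sanity-check the executor; a small sketch (the output count is model-specific, and for YOLOv3-tiny it covers the per-YOLO-layer tensors read further below):
.. code-block:: python
# ``m`` is the GraphModule created above; count its outputs before reading them.
print("number of outputs:", m.get_num_outputs())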
.. GENERATED FROM PYTHON SOURCE LINES 254-258
Perform image detection inference.
----------------------------------
We run detection on a downloaded image.
Download the test image.
.. GENERATED FROM PYTHON SOURCE LINES 258-323
.. code-block:: default
[neth, netw] = dshape[2:]
test_image = "person.jpg"
img_url = REPO_URL + "data/" + test_image + "?raw=true"
img_path = download_testdata(img_url, test_image, "data")
data = darknet.load_image(img_path, neth, netw).transpose(1, 2, 0)
# Prepare test image for inference
plt.imshow(data)
plt.show()
data = data.transpose((2, 0, 1))
data = data[np.newaxis, :]
data = np.repeat(data, env.BATCH, axis=0)
# Set the network parameters and inputs
m.set_input("data", data)
# Perform inference and gather execution statistics
# More on: :py:method:`tvm.runtime.Module.time_evaluator`
num = 4 # number of times we run module for a single measurement
rep = 3 # number of measurements (we derive std dev from this)
timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep)
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
timer()
sim_stats = simulator.stats()
print("\nExecution statistics:")
for k, v in sim_stats.items():
# Since we execute the workload many times, we need to normalize stats
# Note that there is always one warm up run
# Therefore we divide the overall stats by (num * rep + 1)
print("\t{:<16}: {:>16}".format(k, v // (num * rep + 1)))
else:
tcost = timer()
std = np.std(tcost.results) * 1000
mean = tcost.mean * 1000
print("\nPerformed inference in %.2fms (std = %.2f) for %d samples" % (mean, std, env.BATCH))
print("Average per sample inference time: %.2fms" % (mean / env.BATCH))
# Get detection results from out
thresh = 0.5
nms_thresh = 0.45
tvm_out = []
for i in range(2):
    layer_out = {}
    layer_out["type"] = "Yolo"
    # Get the yolo layer attributes (n, out_c, out_h, out_w, classes, total)
    layer_attr = m.get_output(i * 4 + 3).numpy()
    layer_out["biases"] = m.get_output(i * 4 + 2).numpy()
    layer_out["mask"] = m.get_output(i * 4 + 1).numpy()
    out_shape = (layer_attr[0], layer_attr[1] // layer_attr[0], layer_attr[2], layer_attr[3])
    layer_out["output"] = m.get_output(i * 4).numpy().reshape(out_shape)
    layer_out["classes"] = layer_attr[4]
    tvm_out.append(layer_out)
thresh = 0.560
# Show detection results
img = darknet.load_image_color(img_path)
_, im_h, im_w = img.shape
dets = yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, 1, tvm_out)
last_layer = net.layers[net.n - 1]
yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh)
yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes)
plt.imshow(img.transpose(1, 2, 0))
plt.show()
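On a headless machine where ``plt.show()`` cannot open a window, a small variation that writes the annotated image to disk instead (the file name is arbitrary):
.. code-block:: python
# Save the detection results to a file instead of displaying them interactively.
plt.imshow(img.transpose(1, 2, 0))
plt.savefig("yolov3_tiny_detections.png", bbox_inches="tight")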
.. image-sg:: /topic/vta/tutorials/frontend/images/sphx_glr_deploy_detection_001.png
:alt: deploy detection
:srcset: /topic/vta/tutorials/frontend/images/sphx_glr_deploy_detection_001.png
:class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out
.. code-block:: none
Execution statistics:
inp_load_nbytes : 25462784
wgt_load_nbytes : 17558016
acc_load_nbytes : 96128
uop_load_nbytes : 5024
out_store_nbytes: 3396224
gemm_counter : 10578048
alu_counter : 849056
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 4.032 seconds)
.. _sphx_glr_download_topic_vta_tutorials_frontend_deploy_detection.py:
.. only:: html
.. container:: sphx-glr-footer sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: deploy_detection.py <deploy_detection.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: deploy_detection.ipynb <deploy_detection.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_