.. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY
.. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE
.. CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "how_to/compile_models/from_tensorflow.py"
.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        This tutorial can be used interactively with Google Colab! You can also click
        :ref:`here <sphx_glr_download_how_to_compile_models_from_tensorflow.py>` to run the Jupyter notebook locally.

        .. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/utilities/colab_button.svg
            :align: center
            :target: https://colab.research.google.com/github/apache/tvm-site/blob/asf-site/docs/_downloads/83e3b018e8bac8d31bb331d200a33a04/from_tensorflow.ipynb
            :width: 300px

.. rst-class:: sphx-glr-example-title
.. _sphx_glr_how_to_compile_models_from_tensorflow.py:
Compile Tensorflow Models
=========================
This article is an introductory tutorial on deploying TensorFlow models with TVM.

To begin, the TensorFlow Python module must be installed:

.. code-block:: bash

    %%shell
    pip install tensorflow

For other installation methods, please refer to https://www.tensorflow.org/install.

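Before running the rest of the tutorial, it can help to confirm that its Python dependencies are importable. The sketch below uses only the standard library's ``importlib.util.find_spec``; the helper name ``missing_packages`` is our own, and the package list (``tvm``, ``tensorflow``, ``PIL`` from Pillow, ``numpy``) is simply what this tutorial imports.

.. code-block:: python

    import importlib.util


    def missing_packages(names):
        """Return the top-level package names that cannot be located (hypothetical helper)."""
        # find_spec returns None when no module with that name can be found on sys.path.
        return [name for name in names if importlib.util.find_spec(name) is None]


    # The packages this tutorial imports; PIL is provided by the Pillow distribution.
    REQUIRED = ["tvm", "tensorflow", "PIL", "numpy"]

Calling ``missing_packages(REQUIRED)`` returns an empty list when everything is installed.
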
.. GENERATED FROM PYTHON SOURCE LINES 31-74
.. code-block:: default

    # tvm, relay
    import tvm
    from tvm import te
    from tvm import relay

    # os and numpy
    import numpy as np
    import os.path

    # Tensorflow imports
    import tensorflow as tf

    # Ask tensorflow to limit its GPU memory to what's actually needed
    # instead of gobbling everything that's available.
    # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
    # This way this tutorial is a little more friendly to sphinx-gallery.
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            print("tensorflow will use experimental.set_memory_growth(True)")
        except RuntimeError as e:
            print("experimental.set_memory_growth option is not available: {}".format(e))

    try:
        tf_compat_v1 = tf.compat.v1
    except (ImportError, AttributeError):
        # Older TensorFlow releases have no tf.compat.v1; fall back to the top-level API.
        tf_compat_v1 = tf

    # Tensorflow utility functions
    import tvm.relay.testing.tf as tf_testing

    # Base location for model related files.
    repo_base = "https://github.com/dmlc/web-data/raw/main/tensorflow/models/InceptionV1/"

    # Test image
    img_name = "elephant-299.jpg"
    image_url = os.path.join(repo_base, img_name)

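The block above builds download URLs with ``os.path.join``, a filesystem helper whose separator depends on the operating system. For URLs, the standard library's ``urllib.parse.urljoin`` is the URL-aware alternative. The sketch below reuses the tutorial's ``repo_base`` and ``img_name`` values to show that both approaches give the same result here, and why the trailing slash on ``repo_base`` matters.

.. code-block:: python

    import posixpath
    from urllib.parse import urljoin

    repo_base = "https://github.com/dmlc/web-data/raw/main/tensorflow/models/InceptionV1/"
    img_name = "elephant-299.jpg"

    # URL-aware join: resolves img_name relative to the base "directory".
    via_urljoin = urljoin(repo_base, img_name)

    # Plain "/"-separated path join: same result, because repo_base ends with "/".
    via_posixpath = posixpath.join(repo_base, img_name)

    # Without the trailing slash, urljoin treats "InceptionV1" as a file name and
    # replaces it, so the image would be looked up one directory too high.
    without_slash = urljoin(repo_base.rstrip("/"), img_name)
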
.. GENERATED FROM PYTHON SOURCE LINES 75-79
Tutorials
---------
Please refer to docs/frontend/tensorflow.md for more details on converting
various TensorFlow models.

.. GENERATED FROM PYTHON SOURCE LINES 79-100
.. code-block:: default

    model_name = "classify_image_graph_def-with_shapes.pb"
    model_url = os.path.join(repo_base, model_name)

    # Image label map
    map_proto = "imagenet_2012_challenge_label_map_proto.pbtxt"
    map_proto_url = os.path.join(repo_base, map_proto)

    # Human readable text for labels
    label_map = "imagenet_synset_to_human_label_map.txt"
    label_map_url = os.path.join(repo_base, label_map)

    # Target settings
    # Use these commented settings to build for cuda.
    # target = tvm.target.Target("cuda", host="llvm")
    # layout = "NCHW"
    # dev = tvm.cuda(0)
    target = tvm.target.Target("llvm", host="llvm")
    layout = None
    dev = tvm.cpu(0)

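The CPU and CUDA settings above differ in three values at once (target, layout and device). A minimal sketch of keeping them together behind one flag follows; ``target_settings`` is a hypothetical helper of our own, and it returns plain strings so that it runs without TVM. The commented lines indicate how the values could then be passed to ``tvm.target.Target`` and ``tvm.device``.

.. code-block:: python

    def target_settings(use_cuda=False):
        """Hypothetical helper: the tutorial's CPU or CUDA build settings as plain values."""
        if use_cuda:
            # Mirrors the commented-out CUDA settings: convert to NCHW, run on a CUDA device.
            return {"target": "cuda", "host": "llvm", "layout": "NCHW", "device": "cuda"}
        # Mirrors the active CPU settings: no layout conversion, run on the CPU.
        return {"target": "llvm", "host": "llvm", "layout": None, "device": "cpu"}


    settings = target_settings(use_cuda=False)
    # With TVM installed, the values would be used roughly as follows:
    # target = tvm.target.Target(settings["target"], host=settings["host"])
    # layout = settings["layout"]
    # dev = tvm.device(settings["device"], 0)
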
.. GENERATED FROM PYTHON SOURCE LINES 101-104
Download required files
-----------------------
Download the files listed above.

.. GENERATED FROM PYTHON SOURCE LINES 104-111
.. code-block:: default

    from tvm.contrib.download import download_testdata

    img_path = download_testdata(image_url, img_name, module="data")
    model_path = download_testdata(model_url, model_name, module=["tf", "InceptionV1"])
    map_proto_path = download_testdata(map_proto_url, map_proto, module="data")
    label_path = download_testdata(label_map_url, label_map, module="data")

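``download_testdata`` returns a local path for each downloaded file. If you want the same download-once behaviour in a script that does not depend on TVM, a rough standard-library sketch follows; ``fetch_if_missing`` is a hypothetical helper, not TVM's implementation, and it does no checksum or retry handling.

.. code-block:: python

    import os
    import urllib.request


    def fetch_if_missing(url, dest):
        """Download url to dest unless dest already exists (hypothetical helper).

        Returns True if a download happened, False if the existing file was reused.
        """
        if os.path.exists(dest):
            return False
        parent = os.path.dirname(dest)
        if parent:
            os.makedirs(parent, exist_ok=True)
        # Download to a temporary name first so that an interrupted transfer is
        # not mistaken for a complete cached file on the next run.
        tmp = dest + ".part"
        urllib.request.urlretrieve(url, tmp)
        os.replace(tmp, dest)
        return True
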
.. GENERATED FROM PYTHON SOURCE LINES 112-115
Import model
------------
Create the TensorFlow graph definition from the protobuf file.

.. GENERATED FROM PYTHON SOURCE LINES 115-126
.. code-block:: default

    with tf_compat_v1.gfile.GFile(model_path, "rb") as f:
        graph_def = tf_compat_v1.GraphDef()
        graph_def.ParseFromString(f.read())
        graph = tf.import_graph_def(graph_def, name="")
        # Call the utility to import the graph definition into default graph.
        graph_def = tf_testing.ProcessGraphDefParam(graph_def)
        # Add shapes to the graph.
        with tf_compat_v1.Session() as sess:
            graph_def = tf_testing.AddShapesToGraphDef(sess, "softmax")

.. GENERATED FROM PYTHON SOURCE LINES 127-135
Decode image
------------
.. note::

    The TensorFlow frontend does not support preprocessing ops such as JpegDecode.
    JpegDecode is bypassed (the importer simply returns its source node),
    so we supply an already-decoded frame to TVM instead.

.. GENERATED FROM PYTHON SOURCE LINES 135-142
.. code-block:: default

    from PIL import Image

    image = Image.open(img_path).resize((299, 299))

    x = np.array(image)

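Because JpegDecode is a pass-through, the array fed to ``DecodeJpeg/contents`` must already have the decoded layout the rest of the graph expects. For this RGB JPEG, ``np.array(image)`` gives a ``(299, 299, 3)`` ``uint8`` array in height-width-channel order, but a grayscale or RGBA image would give a different shape. The sketch below is one hypothetical way to coerce such arrays with NumPy (``as_hwc_rgb_uint8`` is our own name); with PIL, calling ``image.convert("RGB")`` before ``np.array`` achieves the same.

.. code-block:: python

    import numpy as np


    def as_hwc_rgb_uint8(frame):
        """Coerce a decoded image array to (height, width, 3) uint8 (hypothetical helper)."""
        frame = np.asarray(frame)
        if frame.ndim == 2:
            # Grayscale: repeat the single channel three times.
            frame = np.stack([frame, frame, frame], axis=-1)
        elif frame.ndim == 3 and frame.shape[-1] == 4:
            # RGBA: drop the alpha channel.
            frame = frame[..., :3]
        if frame.ndim != 3 or frame.shape[-1] != 3:
            raise ValueError("unexpected image shape {}".format(frame.shape))
        # Cast only; float images in [0, 1] would need rescaling to [0, 255] first.
        return frame.astype(np.uint8, copy=False)
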
.. GENERATED FROM PYTHON SOURCE LINES 143-150
Import the graph to Relay
-------------------------
Import the TensorFlow graph definition into the Relay frontend.

Results:
  mod: Relay module (IRModule) for the given TensorFlow protobuf.
  params: parameters converted from the TensorFlow params (tensor protobuf).

.. GENERATED FROM PYTHON SOURCE LINES 150-155
.. code-block:: default

    shape_dict = {"DecodeJpeg/contents": x.shape}
    dtype_dict = {"DecodeJpeg/contents": "uint8"}
    mod, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict)

    print("Tensorflow protobuf imported to relay frontend.")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /workspace/python/tvm/relay/frontend/tensorflow.py:537: UserWarning: Ignore the passed shape. Shape in graphdef will be used for operator DecodeJpeg/contents.
      warnings.warn(
    /workspace/python/tvm/relay/frontend/tensorflow_ops.py:1036: UserWarning: DecodeJpeg: It's a pass through, please handle preprocessing before input
      warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input")
    Tensorflow protobuf imported to relay frontend.

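The Relay frontend keys ``shape_dict`` by the node name ``DecodeJpeg/contents``, while the TensorFlow comparison later in this tutorial feeds the tensor ``DecodeJpeg/contents:0``, that is, output 0 of that node. A small sketch for converting between the two forms follows; ``node_name`` is a hypothetical helper, and it handles only the ``name:index`` suffix, not other TensorFlow naming conventions.

.. code-block:: python

    def node_name(tensor_name):
        """Strip a trailing output index such as ':0' from a TensorFlow tensor name (hypothetical helper)."""
        name, sep, index = tensor_name.rpartition(":")
        if sep and index.isdigit():
            return name
        # No ':<digits>' suffix: the string is already a node name.
        return tensor_name


    # Build a frontend-style shape dict from the tensor name used with tf.Session.
    example_shape_dict = {node_name("DecodeJpeg/contents:0"): (299, 299, 3)}
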
.. GENERATED FROM PYTHON SOURCE LINES 156-164
Relay Build
-----------
Compile the graph for the llvm target with the given input specification.

Results:
  graph: final graph after compilation.
  params: final params after compilation.
  lib: target library that can be deployed on the target with the TVM runtime.

All three are bundled in the single module returned by ``relay.build``.

.. GENERATED FROM PYTHON SOURCE LINES 164-168
.. code-block:: default

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target, params=params)

.. GENERATED FROM PYTHON SOURCE LINES 169-172
Execute the portable graph on TVM
---------------------------------
Now we can deploy the compiled model on the target.

.. GENERATED FROM PYTHON SOURCE LINES 172-184
.. code-block:: default

    from tvm.contrib import graph_executor

    dtype = "uint8"
    m = graph_executor.GraphModule(lib["default"](dev))
    # set inputs
    m.set_input("DecodeJpeg/contents", tvm.nd.array(x.astype(dtype)))
    # execute
    m.run()
    # get outputs
    tvm_output = m.get_output(0, tvm.nd.empty((1, 1008), "float32"))

.. GENERATED FROM PYTHON SOURCE LINES 185-188
Process the output
------------------
Convert the model output into human-readable labels for InceptionV1.

.. GENERATED FROM PYTHON SOURCE LINES 188-201
.. code-block:: default

    predictions = tvm_output.numpy()
    predictions = np.squeeze(predictions)

    # Creates node ID --> English string lookup.
    node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path, uid_lookup_path=label_path)

    # Print top 5 predictions from TVM output.
    top_k = predictions.argsort()[-5:][::-1]
    for node_id in top_k:
        human_string = node_lookup.id_to_string(node_id)
        score = predictions[node_id]
        print("%s (score = %.5f)" % (human_string, score))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    African elephant, Loxodonta africana (score = 0.61481)
    tusker (score = 0.30387)
    Indian elephant, Elephas maximus (score = 0.03343)
    banana (score = 0.00023)
    rapeseed (score = 0.00021)

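The top-5 selection above, ``predictions.argsort()[-5:][::-1]``, sorts all 1008 scores and keeps the last five in descending order. For a vector this small that is fine; for much larger label spaces, ``np.argpartition`` can find the k largest scores without a full sort. The sketch below (``top_k_indices`` is our own name) returns the same indices as the tutorial's expression when there are no tied scores; with ties, the order among equal scores may differ.

.. code-block:: python

    import numpy as np


    def top_k_indices(scores, k=5):
        """Indices of the k largest scores, highest first (hypothetical helper)."""
        # Drop the batch dimension, e.g. (1, 1008) -> (1008,).
        scores = np.squeeze(np.asarray(scores))
        if k <= 0:
            return np.array([], dtype=np.intp)
        if k >= scores.size:
            return np.argsort(scores)[::-1]
        # argpartition places the indices of the k largest scores at the end, unordered.
        candidates = np.argpartition(scores, -k)[-k:]
        # Only those k candidates are then sorted, in descending order of score.
        return candidates[np.argsort(scores[candidates])[::-1]]
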
.. GENERATED FROM PYTHON SOURCE LINES 202-205
Inference on tensorflow
-----------------------
Run the same model in TensorFlow for comparison.

.. GENERATED FROM PYTHON SOURCE LINES 205-258
.. code-block:: default

    def create_graph():
        """Imports the saved GraphDef file into the default graph."""
        # Creates graph from saved graph_def.pb.
        with tf_compat_v1.gfile.GFile(model_path, "rb") as f:
            graph_def = tf_compat_v1.GraphDef()
            graph_def.ParseFromString(f.read())
            graph = tf.import_graph_def(graph_def, name="")
            # Call the utility to import the graph definition into default graph.
            graph_def = tf_testing.ProcessGraphDefParam(graph_def)


    def run_inference_on_image(image):
        """Runs inference on an image.

        Parameters
        ----------
        image: str
            Image file name.

        Returns
        -------
        None
        """
        if not tf_compat_v1.gfile.Exists(image):
            tf_compat_v1.logging.fatal("File does not exist %s", image)
        image_data = tf_compat_v1.gfile.GFile(image, "rb").read()

        # Creates graph from saved GraphDef.
        create_graph()

        with tf_compat_v1.Session() as sess:
            softmax_tensor = sess.graph.get_tensor_by_name("softmax:0")
            predictions = sess.run(softmax_tensor, {"DecodeJpeg/contents:0": image_data})

            predictions = np.squeeze(predictions)

            # Creates node ID --> English string lookup.
            node_lookup = tf_testing.NodeLookup(
                label_lookup_path=map_proto_path, uid_lookup_path=label_path
            )

            # Print top 5 predictions from tensorflow.
            top_k = predictions.argsort()[-5:][::-1]
            print("===== TENSORFLOW RESULTS =======")
            for node_id in top_k:
                human_string = node_lookup.id_to_string(node_id)
                score = predictions[node_id]
                print("%s (score = %.5f)" % (human_string, score))


    run_inference_on_image(img_path)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ===== TENSORFLOW RESULTS =======
    African elephant, Loxodonta africana (score = 0.58394)
    tusker (score = 0.33909)
    Indian elephant, Elephas maximus (score = 0.03186)
    banana (score = 0.00022)
    desk (score = 0.00019)

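The TVM and TensorFlow runs agree on the top four labels, but their scores differ (0.61481 versus 0.58394 for the top class) and their fifth-ranked labels differ (rapeseed versus desk). When checking a compiled model against its reference framework, comparing rankings can therefore be more informative than comparing raw probabilities. A minimal NumPy sketch of such a check follows; the function names are hypothetical, and the test vectors are toy values, not the tutorial's outputs.

.. code-block:: python

    import numpy as np


    def same_top1(scores_a, scores_b):
        """True if both score vectors pick the same highest-scoring class."""
        return int(np.argmax(scores_a)) == int(np.argmax(scores_b))


    def top_k_overlap(scores_a, scores_b, k=5):
        """Fraction of classes shared by the two top-k sets (ranking order ignored)."""
        top_a = set(np.argsort(np.ravel(scores_a))[-k:].tolist())
        top_b = set(np.argsort(np.ravel(scores_b))[-k:].tolist())
        return len(top_a & top_b) / k
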
.. rst-class:: sphx-glr-timing

**Total running time of the script:** ( 1 minutes 37.995 seconds)

.. _sphx_glr_download_how_to_compile_models_from_tensorflow.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: from_tensorflow.py <from_tensorflow.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: from_tensorflow.ipynb <from_tensorflow.ipynb>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_