blob: 888235a85992541fd557fa3426ebb0c32452dcfc [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: E402, E501
"""
.. _import_model:
Importing Models from ML Frameworks
====================================
Apache TVM supports importing models from popular ML frameworks including PyTorch, ONNX,
and TensorFlow Lite. This tutorial walks through each import path with a minimal working
example and explains the key parameters. The PyTorch section additionally demonstrates
how to handle unsupported operators via a custom converter map.
For end-to-end optimization and deployment after importing, see :ref:`optimize_model`.
.. note::
The ONNX section requires the ``onnx`` package. The TFLite section requires
``tensorflow`` and ``tflite``. Sections whose dependencies are missing are skipped
automatically.
.. contents:: Table of Contents
:local:
:depth: 2
"""
######################################################################
# Importing from PyTorch (Recommended)
# -------------------------------------
# TVM's PyTorch frontend is the most feature-complete. The recommended entry point is
# :py:func:`~tvm.relax.frontend.torch.from_exported_program`, which works with PyTorch's
# ``torch.export`` API.
#
# We start by defining a small CNN model for demonstration. No pretrained weights are
# needed — we only care about the graph structure.
import numpy as np
import torch
from torch import nn
from torch.export import export
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program
class SimpleCNN(nn.Module):
    """Tiny conv -> batchnorm -> global-pool -> linear network for the import demos.

    Only the graph structure matters for this tutorial, so the randomly
    initialized weights are never trained.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        # conv -> BN -> ReLU, then average-pool down to (N, 16, 1, 1).
        features = torch.relu(self.bn(self.conv(x)))
        # Collapse the trailing spatial dims before the classifier head.
        flat = torch.flatten(self.pool(features), 1)
        return self.fc(flat)
# Instantiate the demo model in inference (eval) mode — BatchNorm then uses its
# running statistics instead of per-batch statistics during export and testing.
torch_model = SimpleCNN().eval()
# Example inputs for torch.export: one (batch=1, C=3, H=32, W=32) float tensor.
example_args = (torch.randn(1, 3, 32, 32),)
######################################################################
# Basic import
# ~~~~~~~~~~~~
# The standard workflow is: ``torch.export.export()`` → ``from_exported_program()`` →
# ``detach_params()``.
# Capture the model with torch.export, then translate the ExportedProgram into
# a Relax IRModule. ``keep_params_as_input=True`` keeps the weights as explicit
# function parameters so they can be split out afterwards.
with torch.no_grad():
    exported_program = export(torch_model, example_args)
    mod = from_exported_program(
        exported_program,
        keep_params_as_input=True,
        unwrap_unit_return_tuple=True,
    )

# Split the weights from the IRModule; ``params`` maps function name -> weights.
mod, params = relax.frontend.detach_params(mod)
mod.show()
######################################################################
# Key parameters
# ~~~~~~~~~~~~~~
# ``from_exported_program`` accepts several parameters that control how the model is
# translated:
#
# - **keep_params_as_input** (bool, default ``False``): When ``True``, model weights become
# function parameters, separated via ``relax.frontend.detach_params()``. When ``False``,
# weights are embedded as constants inside the IRModule. Use ``True`` when you want to
# manage weights independently (e.g., for weight sharing or quantization).
#
# - **unwrap_unit_return_tuple** (bool, default ``False``): PyTorch ``export`` always wraps
# the return value in a tuple. Set ``True`` to unwrap single-element return tuples for a
# cleaner Relax function signature.
#
# - **run_ep_decomposition** (bool, default ``True``): Runs PyTorch's built-in operator
# decomposition before translation. This breaks high-level ops (e.g., ``batch_norm``) into
# lower-level primitives, which generally improves TVM's coverage and optimization
# opportunities. Set ``False`` if you want to preserve the original op granularity.
######################################################################
# Handling unsupported operators with ``custom_convert_map``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# When TVM encounters a PyTorch operator it does not recognize, it raises an error
# indicating the unsupported operator name. You can extend the frontend by providing a
# **custom converter map** — a dictionary mapping operator names to your own conversion
# functions.
#
# A custom converter function receives two arguments:
#
# - **node** (``torch.fx.Node``): The FX graph node being converted, carrying operator
# info and references to input nodes.
# - **importer** (``ExportedProgramImporter``): The importer instance, giving access to:
#
# - ``importer.env``: Dict mapping FX nodes to their converted Relax expressions.
# - ``importer.block_builder``: The Relax ``BlockBuilder`` for emitting operations.
# - ``importer.retrieve_args(node)``: Helper to look up converted args.
#
# The function must return a ``relax.Var`` — the Relax expression for this node's output.
# Here is an example that maps an operator to ``relax.op.sigmoid``:
from tvm.relax.frontend.torch.exported_program_translator import ExportedProgramImporter


def convert_sigmoid(node: torch.fx.Node, importer: ExportedProgramImporter) -> relax.Var:
    """Custom converter that lowers an FX node to ``relax.op.sigmoid``.

    Looks up the node's already-converted inputs via the importer, builds the
    sigmoid expression, and emits it through the importer's block builder.
    """
    inputs = importer.retrieve_args(node)
    sigmoid_expr = relax.op.sigmoid(inputs[0])
    return importer.block_builder.emit(sigmoid_expr)
######################################################################
# To use the custom converter, pass it via the ``custom_convert_map`` parameter. The key
# is the ATen operator name in ``"op_name.variant"`` format (e.g., ``"sigmoid.default"``):
#
# .. code-block:: python
#
# mod = from_exported_program(
# exported_program,
# custom_convert_map={"sigmoid.default": convert_sigmoid},
# )
#
# .. note::
#
# To find the correct operator name, check the error message TVM raises when encountering
# the unsupported op — it includes the exact ATen name. You can also inspect the exported
# program's graph via ``print(exported_program.graph_module.graph)`` to see all operator
# names.
######################################################################
# Alternative PyTorch import methods
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Besides ``from_exported_program``, TVM also provides:
#
# - :py:func:`~tvm.relax.frontend.torch.from_fx`: Works with ``torch.fx.GraphModule``
# from ``torch.fx.symbolic_trace()``. Requires explicit ``input_info`` (shapes and dtypes).
# Use this when ``torch.export`` fails on certain Python control flow patterns.
#
# - :py:func:`~tvm.relax.frontend.torch.relax_dynamo`: A ``torch.compile`` backend that
# compiles and executes the model through TVM in one step. Useful for integrating TVM
# into an existing PyTorch training or inference loop.
#
# - :py:func:`~tvm.relax.frontend.torch.dynamo_capture_subgraphs`: Captures subgraphs from
# a PyTorch model into an IRModule via ``torch.compile``. Each subgraph becomes a separate
# function in the IRModule.
#
# For most use cases, ``from_exported_program`` is the recommended path.
######################################################################
# Verifying the imported model
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# After importing, it is good practice to verify that TVM produces the same output as the
# original framework. We compile with the minimal ``"zero"`` pipeline (no tuning) and
# compare. The same approach applies to models imported via the ONNX and TFLite frontends
# shown below.
# Compile with the minimal "zero" pipeline (no tuning) and run the result on
# CPU through the Relax virtual machine.
dev = tvm.cpu()
mod_compiled = relax.get_pipeline("zero")(mod)
exec_module = tvm.compile(mod_compiled, target="llvm")
vm = relax.VirtualMachine(exec_module, dev)

# Run inference through TVM: wrap the input and the detached weights as
# TVM tensors, then call the "main" function of the compiled module.
input_data = np.random.rand(1, 3, 32, 32).astype("float32")
tvm_input = tvm.runtime.tensor(input_data, dev)
tvm_params = [tvm.runtime.tensor(weight, dev) for weight in params["main"]]
tvm_out = vm["main"](tvm_input, *tvm_params).numpy()

# Run the same input through PyTorch as the reference result.
with torch.no_grad():
    pt_out = torch_model(torch.from_numpy(input_data)).numpy()

np.testing.assert_allclose(tvm_out, pt_out, rtol=1e-5, atol=1e-5)
print("PyTorch vs TVM outputs match!")
######################################################################
# Importing from ONNX
# --------------------
# TVM can import ONNX models via :py:func:`~tvm.relax.frontend.onnx.from_onnx`. The
# function accepts an ``onnx.ModelProto`` object, so you need to load the model with
# ``onnx.load()`` first.
#
# Here we export the same CNN model to ONNX format and then import it into TVM.
try:
    import onnx
    import onnxscript  # noqa: F401 # required by torch.onnx.export
except ImportError:
    onnx = None  # type: ignore[assignment]

# ``onnx`` is reset to None whenever either dependency is missing.
HAS_ONNX = onnx is not None

if HAS_ONNX:
    from tvm.relax.frontend.onnx import from_onnx

    # Export the PyTorch model to an ONNX file on disk.
    dummy_input = torch.randn(1, 3, 32, 32)
    onnx_path = "simple_cnn.onnx"
    torch.onnx.export(torch_model, dummy_input, onnx_path, input_names=["input"])

    # Load the ModelProto and translate it into a Relax IRModule.
    onnx_model = onnx.load(onnx_path)
    mod_onnx = from_onnx(onnx_model, keep_params_in_input=True)
    mod_onnx, params_onnx = relax.frontend.detach_params(mod_onnx)
    mod_onnx.show()
######################################################################
# If you already have an ``.onnx`` file on disk, the workflow is even simpler:
#
# .. code-block:: python
#
# import onnx
# from tvm.relax.frontend.onnx import from_onnx
#
# onnx_model = onnx.load("my_model.onnx")
# mod = from_onnx(onnx_model)
#
######################################################################
# Key parameters
# ~~~~~~~~~~~~~~
# - **shape_dict** (dict, optional): Maps input names to shapes. Auto-inferred from the
# model if not provided. Useful when the ONNX model has dynamic dimensions that you
# want to fix to concrete sizes:
#
# .. code-block:: python
#
# mod = from_onnx(onnx_model, shape_dict={"input": [1, 3, 224, 224]})
#
# - **dtype_dict** (str or dict, default ``"float32"``): Input dtypes. A single string
# applies to all inputs, or use a dict to set per-input dtypes:
#
# .. code-block:: python
#
# mod = from_onnx(onnx_model, dtype_dict={"input": "float16"})
#
# - **keep_params_in_input** (bool, default ``False``): Same semantics as PyTorch — whether
# model weights are function parameters or embedded constants.
#
# - **opset** (int, optional): Override the opset version auto-detected from the model.
# Each ONNX op may have different semantics across opset versions; TVM's converter
# selects the appropriate implementation automatically. You rarely need to set this
# unless the model metadata is incorrect.
######################################################################
# Importing from TensorFlow Lite
# -------------------------------
# TVM can import TFLite flat buffer models via
# :py:func:`~tvm.relax.frontend.tflite.from_tflite`. The function expects a TFLite
# ``Model`` object parsed from flat buffer bytes via ``GetRootAsModel``.
#
# .. note::
#
# The ``tflite`` Python package has changed its module layout across versions.
# Older versions use ``tflite.Model.Model.GetRootAsModel``, while newer versions use
# ``tflite.Model.GetRootAsModel``. The code below handles both.
#
# Below we create a minimal TFLite model from TensorFlow and import it.
try:
    import tensorflow as tf
    import tflite
    import tflite.Model

    HAS_TFLITE = True
except ImportError:
    HAS_TFLITE = False

if HAS_TFLITE:
    from tvm.relax.frontend.tflite import from_tflite

    # Build a tiny TF module and convert it to a TFLite flat buffer.
    # Plain TF ops (rather than keras layers) avoid variable-handling ops
    # that some TFLite converter versions do not support cleanly.
    class TFModule(tf.Module):
        @tf.function(
            input_signature=[
                tf.TensorSpec(shape=(1, 784), dtype=tf.float32),
                tf.TensorSpec(shape=(784, 10), dtype=tf.float32),
            ]
        )
        def forward(self, x, weight):
            return tf.matmul(x, weight) + 0.1

    tf_module = TFModule()
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [tf_module.forward.get_concrete_function()], tf_module
    )
    tflite_buf = converter.convert()

    # Parse the flat buffer. The entry point moved between tflite package
    # versions: older ones expose tflite.Model.Model, newer ones tflite.Model.
    root = tflite.Model.Model if hasattr(tflite.Model, "Model") else tflite.Model
    tflite_model = root.GetRootAsModel(tflite_buf, 0)

    mod_tflite = from_tflite(tflite_model)
    mod_tflite.show()
######################################################################
# Loading from a ``.tflite`` file
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# If you already have a ``.tflite`` file on disk, load the raw bytes and parse them:
#
# .. code-block:: python
#
# import tflite
# import tflite.Model
# from tvm.relax.frontend.tflite import from_tflite
#
# with open("my_model.tflite", "rb") as f:
# tflite_buf = f.read()
#
# if hasattr(tflite.Model, "Model"):
# tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)
# else:
# tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)
# mod = from_tflite(tflite_model)
######################################################################
# Key parameters
# ~~~~~~~~~~~~~~
# - **shape_dict** / **dtype_dict** (optional): Override input shapes and dtypes. If not
# provided, they are inferred from the TFLite model metadata.
#
# - **op_converter** (class, optional): A custom operator converter class. Subclass
# ``OperatorConverter`` and override its ``convert_map`` dictionary to add or replace
# operator conversions. For example, to add a hypothetical ``CUSTOM_RELU`` op:
#
# .. code-block:: python
#
# from tvm.relax.frontend.tflite.tflite_frontend import OperatorConverter
#
# class MyConverter(OperatorConverter):
# def __init__(self, model, subgraph, exp_tab, ctx):
# super().__init__(model, subgraph, exp_tab, ctx)
# self.convert_map["CUSTOM_RELU"] = self._convert_custom_relu
#
# def _convert_custom_relu(self, op):
# # implement your conversion logic here
# ...
#
# mod = from_tflite(tflite_model, op_converter=MyConverter)
######################################################################
# Summary
# -------
#
# +---------------------+----------------------------+-------------------------------+-----------------------------+
# | Aspect | PyTorch | ONNX | TFLite |
# +=====================+============================+===============================+=============================+
# | Entry function | ``from_exported_program`` | ``from_onnx`` | ``from_tflite`` |
# +---------------------+----------------------------+-------------------------------+-----------------------------+
# | Input | ``ExportedProgram`` | ``onnx.ModelProto`` | TFLite ``Model`` object |
# +---------------------+----------------------------+-------------------------------+-----------------------------+
# | Custom extension | ``custom_convert_map`` | — | ``op_converter`` class |
# +---------------------+----------------------------+-------------------------------+-----------------------------+
#
# **Which to use?** Pick the frontend that matches your model format:
#
# - Have a PyTorch model? Use ``from_exported_program`` — it has the broadest operator coverage.
# - Have an ``.onnx`` file? Use ``from_onnx``.
# - Have a ``.tflite`` file? Use ``from_tflite``.
#
# The verification workflow (compile → run → compare) demonstrated in the PyTorch section
# above applies equally to ONNX and TFLite imports.
#
# For the full list of supported operators, see the converter map in each frontend's source:
# PyTorch uses ``create_convert_map()`` in ``exported_program_translator.py``, ONNX uses
# ``_get_convert_map()`` in ``onnx_frontend.py``, and TFLite uses ``convert_map`` in
# ``OperatorConverter`` in ``tflite_frontend.py``.
#
# After importing, refer to :ref:`optimize_model` for optimization and deployment.