| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # ruff: noqa: E402, E501 |
| |
| """ |
| .. _import_model: |
| |
| Importing Models from ML Frameworks |
| ==================================== |
| Apache TVM supports importing models from popular ML frameworks including PyTorch, ONNX, |
| and TensorFlow Lite. This tutorial walks through each import path with a minimal working |
| example and explains the key parameters. The PyTorch section additionally demonstrates |
| how to handle unsupported operators via a custom converter map. |
| |
| For end-to-end optimization and deployment after importing, see :ref:`optimize_model`. |
| |
| .. note:: |
| |
| The ONNX section requires the ``onnx`` package. The TFLite section requires |
| ``tensorflow`` and ``tflite``. Sections whose dependencies are missing are skipped |
| automatically. |
| |
| .. contents:: Table of Contents |
| :local: |
| :depth: 2 |
| """ |
| |
| ###################################################################### |
| # Importing from PyTorch (Recommended) |
| # ------------------------------------- |
| # TVM's PyTorch frontend is the most feature-complete. The recommended entry point is |
| # :py:func:`~tvm.relax.frontend.torch.from_exported_program`, which works with PyTorch's |
| # ``torch.export`` API. |
| # |
| # We start by defining a small CNN model for demonstration. No pretrained weights are |
| # needed — we only care about the graph structure. |
| |
| import numpy as np |
| import torch |
| from torch import nn |
| from torch.export import export |
| |
| import tvm |
| from tvm import relax |
| from tvm.relax.frontend.torch import from_exported_program |
| |
| |
class SimpleCNN(nn.Module):
    """A tiny conv → batch-norm → ReLU → global-pool → linear classifier.

    Used purely to demonstrate the import workflow; the weights stay at their
    random initialization.
    """

    def __init__(self):
        super().__init__()
        # Submodule creation order is fixed (conv, bn, pool, fc): it determines
        # both the RNG consumption at init time and the parameter ordering.
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        features = self.conv(x)
        features = torch.relu(self.bn(features))
        pooled = torch.flatten(self.pool(features), 1)
        return self.fc(pooled)
| |
| |
# Instantiate the demo model, switch it to inference mode, and prepare an
# example input batch in (batch, channel, height, width) layout for export.
torch_model = SimpleCNN()
torch_model.eval()
example_args = (torch.randn(1, 3, 32, 32),)
| |
| ###################################################################### |
| # Basic import |
| # ~~~~~~~~~~~~ |
| # The standard workflow is: ``torch.export.export()`` → ``from_exported_program()`` → |
| # ``detach_params()``. |
| |
# Capture the graph with torch.export and translate the resulting
# ExportedProgram into a Relax IRModule. Gradients are not needed for capture.
with torch.no_grad():
    exported_program = export(torch_model, example_args)
    mod = from_exported_program(exported_program, keep_params_as_input=True, unwrap_unit_return_tuple=True)

# Pull the weights out of the function signature into a separate params dict.
mod, params = relax.frontend.detach_params(mod)
mod.show()
| |
| ###################################################################### |
| # Key parameters |
| # ~~~~~~~~~~~~~~ |
| # ``from_exported_program`` accepts several parameters that control how the model is |
| # translated: |
| # |
| # - **keep_params_as_input** (bool, default ``False``): When ``True``, model weights become |
| # function parameters, separated via ``relax.frontend.detach_params()``. When ``False``, |
| # weights are embedded as constants inside the IRModule. Use ``True`` when you want to |
| # manage weights independently (e.g., for weight sharing or quantization). |
| # |
| # - **unwrap_unit_return_tuple** (bool, default ``False``): PyTorch ``export`` always wraps |
| # the return value in a tuple. Set ``True`` to unwrap single-element return tuples for a |
| # cleaner Relax function signature. |
| # |
| # - **run_ep_decomposition** (bool, default ``True``): Runs PyTorch's built-in operator |
| # decomposition before translation. This breaks high-level ops (e.g., ``batch_norm``) into |
| # lower-level primitives, which generally improves TVM's coverage and optimization |
| # opportunities. Set ``False`` if you want to preserve the original op granularity. |
| |
| ###################################################################### |
| # Handling unsupported operators with ``custom_convert_map`` |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| # When TVM encounters a PyTorch operator it does not recognize, it raises an error |
| # indicating the unsupported operator name. You can extend the frontend by providing a |
| # **custom converter map** — a dictionary mapping operator names to your own conversion |
| # functions. |
| # |
| # A custom converter function receives two arguments: |
| # |
| # - **node** (``torch.fx.Node``): The FX graph node being converted, carrying operator |
| # info and references to input nodes. |
| # - **importer** (``ExportedProgramImporter``): The importer instance, giving access to: |
| # |
| # - ``importer.env``: Dict mapping FX nodes to their converted Relax expressions. |
| # - ``importer.block_builder``: The Relax ``BlockBuilder`` for emitting operations. |
| # - ``importer.retrieve_args(node)``: Helper to look up converted args. |
| # |
| # The function must return a ``relax.Var`` — the Relax expression for this node's output. |
| # Here is an example that maps an operator to ``relax.op.sigmoid``: |
| |
| from tvm.relax.frontend.torch.exported_program_translator import ExportedProgramImporter |
| |
| |
def convert_sigmoid(node: torch.fx.Node, importer: ExportedProgramImporter) -> relax.Var:
    """Example custom converter that lowers a node to ``relax.op.sigmoid``.

    ``importer.retrieve_args`` resolves the node's inputs to their
    already-converted Relax expressions; the new operation is emitted through
    the importer's block builder and its result variable returned.
    """
    data = importer.retrieve_args(node)[0]
    return importer.block_builder.emit(relax.op.sigmoid(data))
| |
| |
| ###################################################################### |
| # To use the custom converter, pass it via the ``custom_convert_map`` parameter. The key |
| # is the ATen operator name in ``"op_name.variant"`` format (e.g., ``"sigmoid.default"``): |
| # |
| # .. code-block:: python |
| # |
| # mod = from_exported_program( |
| # exported_program, |
| # custom_convert_map={"sigmoid.default": convert_sigmoid}, |
| # ) |
| # |
| # .. note:: |
| # |
| # To find the correct operator name, check the error message TVM raises when encountering |
| # the unsupported op — it includes the exact ATen name. You can also inspect the exported |
| # program's graph via ``print(exported_program.graph_module.graph)`` to see all operator |
| # names. |
| |
| ###################################################################### |
| # Alternative PyTorch import methods |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| # Besides ``from_exported_program``, TVM also provides: |
| # |
| # - :py:func:`~tvm.relax.frontend.torch.from_fx`: Works with ``torch.fx.GraphModule`` |
| # from ``torch.fx.symbolic_trace()``. Requires explicit ``input_info`` (shapes and dtypes). |
| # Use this when ``torch.export`` fails on certain Python control flow patterns. |
| # |
| # - :py:func:`~tvm.relax.frontend.torch.relax_dynamo`: A ``torch.compile`` backend that |
| # compiles and executes the model through TVM in one step. Useful for integrating TVM |
| # into an existing PyTorch training or inference loop. |
| # |
| # - :py:func:`~tvm.relax.frontend.torch.dynamo_capture_subgraphs`: Captures subgraphs from |
| # a PyTorch model into an IRModule via ``torch.compile``. Each subgraph becomes a separate |
| # function in the IRModule. |
| # |
| # For most use cases, ``from_exported_program`` is the recommended path. |
| |
| ###################################################################### |
| # Verifying the imported model |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| # After importing, it is good practice to verify that TVM produces the same output as the |
| # original framework. We compile with the minimal ``"zero"`` pipeline (no tuning) and |
| # compare. The same approach applies to models imported via the ONNX and TFLite frontends |
| # shown below. |
| |
# Lower with the minimal "zero" pipeline, compile for CPU, and load the result
# into the Relax virtual machine.
dev = tvm.cpu()
exec_module = tvm.compile(relax.get_pipeline("zero")(mod), target="llvm")
vm = relax.VirtualMachine(exec_module, dev)

# Run inference on a random input; weights are passed explicitly because the
# model was imported with keep_params_as_input=True.
input_data = np.random.rand(1, 3, 32, 32).astype("float32")
tvm_out = vm["main"](
    tvm.runtime.tensor(input_data, dev),
    *(tvm.runtime.tensor(p, dev) for p in params["main"]),
).numpy()

# Reference result from the original PyTorch model on the same input.
with torch.no_grad():
    pt_out = torch_model(torch.from_numpy(input_data)).numpy()

np.testing.assert_allclose(tvm_out, pt_out, rtol=1e-5, atol=1e-5)
print("PyTorch vs TVM outputs match!")
| |
| ###################################################################### |
| # Importing from ONNX |
| # -------------------- |
| # TVM can import ONNX models via :py:func:`~tvm.relax.frontend.onnx.from_onnx`. The |
| # function accepts an ``onnx.ModelProto`` object, so you need to load the model with |
| # ``onnx.load()`` first. |
| # |
| # Here we export the same CNN model to ONNX format and then import it into TVM. |
| |
try:
    import onnx
    import onnxscript  # noqa: F401 # required by torch.onnx.export
except ImportError:
    onnx = None  # type: ignore[assignment]

# The section runs only when both onnx and onnxscript imported successfully.
HAS_ONNX = onnx is not None

if HAS_ONNX:
    from tvm.relax.frontend.onnx import from_onnx

    # Round-trip the PyTorch model through an ONNX file on disk.
    onnx_path = "simple_cnn.onnx"
    torch.onnx.export(torch_model, torch.randn(1, 3, 32, 32), onnx_path, input_names=["input"])

    # Load the ModelProto, translate to Relax, and split out the weights.
    mod_onnx = from_onnx(onnx.load(onnx_path), keep_params_in_input=True)
    mod_onnx, params_onnx = relax.frontend.detach_params(mod_onnx)
    mod_onnx.show()
| |
| ###################################################################### |
| # If you already have an ``.onnx`` file on disk, the workflow is even simpler: |
| # |
| # .. code-block:: python |
| # |
| # import onnx |
| # from tvm.relax.frontend.onnx import from_onnx |
| # |
| # onnx_model = onnx.load("my_model.onnx") |
| # mod = from_onnx(onnx_model) |
| # |
| |
| ###################################################################### |
| # Key parameters |
| # ~~~~~~~~~~~~~~ |
| # - **shape_dict** (dict, optional): Maps input names to shapes. Auto-inferred from the |
| # model if not provided. Useful when the ONNX model has dynamic dimensions that you |
| # want to fix to concrete sizes: |
| # |
| # .. code-block:: python |
| # |
| # mod = from_onnx(onnx_model, shape_dict={"input": [1, 3, 224, 224]}) |
| # |
| # - **dtype_dict** (str or dict, default ``"float32"``): Input dtypes. A single string |
| # applies to all inputs, or use a dict to set per-input dtypes: |
| # |
| # .. code-block:: python |
| # |
| # mod = from_onnx(onnx_model, dtype_dict={"input": "float16"}) |
| # |
| # - **keep_params_in_input** (bool, default ``False``): Same semantics as PyTorch — whether |
| # model weights are function parameters or embedded constants. |
| # |
| # - **opset** (int, optional): Override the opset version auto-detected from the model. |
| # Each ONNX op may have different semantics across opset versions; TVM's converter |
| # selects the appropriate implementation automatically. You rarely need to set this |
| # unless the model metadata is incorrect. |
| |
| ###################################################################### |
| # Importing from TensorFlow Lite |
| # ------------------------------- |
# TVM can import TFLite FlatBuffer models via
# :py:func:`~tvm.relax.frontend.tflite.from_tflite`. The function expects a TFLite
# ``Model`` object parsed from the FlatBuffer bytes via ``GetRootAsModel``.
| # |
| # .. note:: |
| # |
| # The ``tflite`` Python package has changed its module layout across versions. |
| # Older versions use ``tflite.Model.Model.GetRootAsModel``, while newer versions use |
| # ``tflite.Model.GetRootAsModel``. The code below handles both. |
| # |
| # Below we create a minimal TFLite model from TensorFlow and import it. |
| |
try:
    import tensorflow as tf
    import tflite
    import tflite.Model
except ImportError:
    HAS_TFLITE = False
else:
    HAS_TFLITE = True

if HAS_TFLITE:
    from tvm.relax.frontend.tflite import from_tflite

    class TFModule(tf.Module):
        """A single matmul-plus-bias graph expressed with raw TF ops.

        Plain TF ops (rather than keras layers) avoid variable-handling ops
        that some TFLite converter versions do not support cleanly.
        """

        @tf.function(
            input_signature=[
                tf.TensorSpec(shape=(1, 784), dtype=tf.float32),
                tf.TensorSpec(shape=(784, 10), dtype=tf.float32),
            ]
        )
        def forward(self, x, weight):
            return tf.matmul(x, weight) + 0.1

    # Convert the concrete function into TFLite flatbuffer bytes.
    tf_module = TFModule()
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [tf_module.forward.get_concrete_function()], tf_module
    )
    tflite_buf = converter.convert()

    # The tflite package moved GetRootAsModel between versions (older:
    # tflite.Model.Model, newer: tflite.Model). Resolve the namespace once,
    # then parse the buffer and import it into TVM.
    model_ns = tflite.Model.Model if hasattr(tflite.Model, "Model") else tflite.Model
    tflite_model = model_ns.GetRootAsModel(tflite_buf, 0)
    mod_tflite = from_tflite(tflite_model)
    mod_tflite.show()
| |
| ###################################################################### |
| # Loading from a ``.tflite`` file |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| # If you already have a ``.tflite`` file on disk, load the raw bytes and parse them: |
| # |
| # .. code-block:: python |
| # |
| # import tflite |
| # import tflite.Model |
| # from tvm.relax.frontend.tflite import from_tflite |
| # |
| # with open("my_model.tflite", "rb") as f: |
| # tflite_buf = f.read() |
| # |
| # if hasattr(tflite.Model, "Model"): |
| # tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0) |
| # else: |
| # tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0) |
| # mod = from_tflite(tflite_model) |
| |
| ###################################################################### |
| # Key parameters |
| # ~~~~~~~~~~~~~~ |
| # - **shape_dict** / **dtype_dict** (optional): Override input shapes and dtypes. If not |
| # provided, they are inferred from the TFLite model metadata. |
| # |
| # - **op_converter** (class, optional): A custom operator converter class. Subclass |
| # ``OperatorConverter`` and override its ``convert_map`` dictionary to add or replace |
| # operator conversions. For example, to add a hypothetical ``CUSTOM_RELU`` op: |
| # |
| # .. code-block:: python |
| # |
| # from tvm.relax.frontend.tflite.tflite_frontend import OperatorConverter |
| # |
| # class MyConverter(OperatorConverter): |
| # def __init__(self, model, subgraph, exp_tab, ctx): |
| # super().__init__(model, subgraph, exp_tab, ctx) |
| # self.convert_map["CUSTOM_RELU"] = self._convert_custom_relu |
| # |
| # def _convert_custom_relu(self, op): |
| # # implement your conversion logic here |
| # ... |
| # |
| # mod = from_tflite(tflite_model, op_converter=MyConverter) |
| |
| ###################################################################### |
| # Summary |
| # ------- |
| # |
| # +---------------------+----------------------------+-------------------------------+-----------------------------+ |
| # | Aspect | PyTorch | ONNX | TFLite | |
| # +=====================+============================+===============================+=============================+ |
| # | Entry function | ``from_exported_program`` | ``from_onnx`` | ``from_tflite`` | |
| # +---------------------+----------------------------+-------------------------------+-----------------------------+ |
| # | Input | ``ExportedProgram`` | ``onnx.ModelProto`` | TFLite ``Model`` object | |
| # +---------------------+----------------------------+-------------------------------+-----------------------------+ |
| # | Custom extension | ``custom_convert_map`` | — | ``op_converter`` class | |
| # +---------------------+----------------------------+-------------------------------+-----------------------------+ |
| # |
| # **Which to use?** Pick the frontend that matches your model format: |
| # |
| # - Have a PyTorch model? Use ``from_exported_program`` — it has the broadest operator coverage. |
| # - Have an ``.onnx`` file? Use ``from_onnx``. |
| # - Have a ``.tflite`` file? Use ``from_tflite``. |
| # |
| # The verification workflow (compile → run → compare) demonstrated in the PyTorch section |
| # above applies equally to ONNX and TFLite imports. |
| # |
| # For the full list of supported operators, see the converter map in each frontend's source: |
| # PyTorch uses ``create_convert_map()`` in ``exported_program_translator.py``, ONNX uses |
| # ``_get_convert_map()`` in ``onnx_frontend.py``, and TFLite uses ``convert_map`` in |
| # ``OperatorConverter`` in ``tflite_frontend.py``. |
| # |
| # After importing, refer to :ref:`optimize_model` for optimization and deployment. |