blob: 222329ce0bc4b1aefde40def3a238d95bf627fdd [file]
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
.. _relax-vm-arch:
Relax Virtual Machine
=====================
This document explains the Relax VM architecture in detail, covering the compilation pipeline
from Relax IR to bytecode, the instruction set, the execution model, and the Python-level user
interface.
Overview
--------
The end-to-end flow from model to execution is:
1. **Relax IR** — a high-level computational graph (``relax.Function`` inside an ``IRModule``).
2. **Compilation** — ``tvm.compile()`` applies the Relax transformation pipeline, then invokes
``VMCodeGen`` to translate each Relax function into bytecode instructions.
3. **Linking** — TIR functions are compiled to native kernels (via LLVM, CUDA, etc.); the bytecode,
constant pool, and compiled kernels are packaged together into a ``VMExecutable``.
4. **Execution** — at runtime, a ``VirtualMachine`` loads the executable, initializes devices and
memory allocators, and runs the bytecode.
.. code-block:: text
IRModule (Relax + TIR)
▼ relax_pipeline (FuseOps, LegalizeOps, ...)
IRModule (optimized)
▼ VMCodeGen
ExecBuilder (bytecode) + IRModule (TIR only)
│ │
│ ▼ tirx.build()
│ runtime.Module (native kernels)
│ │
▼ VMLink ▼
VMExecutable ◄───────── linked together
▼ VirtualMachine(exec, device)
Runtime execution
Compilation: From Relax IR to Bytecode
--------------------------------------
Build entry point
~~~~~~~~~~~~~~~~~
The main entry point is ``tvm.compile()`` (which delegates to ``relax.build()`` in
``python/tvm/relax/vm_build.py``):
.. code-block:: python
import tvm
from tvm import relax
@tvm.script.ir_module
class MyModule:
@R.function
def main(x: R.Tensor((3, 4), "float32")):
return R.add(x, x)
target = tvm.target.Target("llvm")
ex = tvm.compile(MyModule, target)
Internally, ``relax.build()`` performs these steps:
1. Apply the **Relax pipeline** (``relax.get_pipeline("default")``), which includes operator
legalization, fusion, buffer planning, and other graph-level passes.
2. Create an ``ExecBuilder`` and run **VMCodeGen** (``src/relax/backend/vm/codegen_vm.cc``),
which walks each ``relax.Function`` and emits bytecode instructions. The Relax functions are
removed from the IRModule; only TIR functions remain.
3. Compile the remaining TIR functions to native code via ``tirx.build()``.
4. **Link** the bytecode executable with the compiled native module using ``VMLink``, producing
a ``VMExecutable``.
Two execution modes are supported:
- ``exec_mode="bytecode"`` (default): Relax functions are interpreted by the VM's bytecode
dispatch loop.
- ``exec_mode="compiled"``: Relax functions are compiled into TIR functions (``VMTIRCodeGen``)
that directly manipulate the register file, bypassing the interpreter loop. This avoids
dispatch overhead but produces more code.
Bytecode generation
~~~~~~~~~~~~~~~~~~~
The ``CodeGenVM`` class (``src/relax/backend/vm/codegen_vm.cc``) is an ``ExprFunctor`` that visits
each Relax expression and emits instructions through the ``ExecBuilder``:
- Each ``relax.Var`` is mapped to a register.
- Function parameters occupy registers 0 through N-1.
- Each binding in a ``SeqExpr`` generates one or more instructions; the result is stored in a
new register.
- Function calls (``R.call_tir``, ``R.call_packed``, operator calls) become ``Call`` instructions.
- Conditional expressions (``relax.If``, written as Python ``if`` in TVMScript) become an ``If``
instruction followed by ``Goto`` to skip branches.
- The function body ends with a ``Ret`` instruction.
Instruction Set
---------------
The VM uses a **register-based** architecture with an intentionally minimal instruction set.
There are only four opcodes:
.. list-table::
:header-rows: 1
:widths: 15 30 55
* - Opcode
- Fields
- Semantics
* - ``Call``
- ``dst``, ``func_idx``, ``num_args``, ``args[]``
- Call function ``func_idx`` with the given arguments; store the result in register ``dst``.
* - ``Ret``
- ``result``
- Return the value in register ``result`` to the caller.
* - ``Goto``
- ``pc_offset``
- Jump forward or backward by ``pc_offset`` instructions.
* - ``If``
- ``cond``, ``false_offset``
- If register ``cond`` is nonzero, fall through (pc++); otherwise jump by ``false_offset``.
The VM itself performs **no mathematical computation**. All actual work — matrix multiplications,
convolutions, elementwise operations — is carried out by compiled TIR kernels or external
libraries (cuBLAS, cuDNN, etc.), dispatched through ``Call`` instructions.
Instruction encoding
~~~~~~~~~~~~~~~~~~~~
Each instruction argument (``Instruction::Arg``) is a 64-bit word encoded as:
- **Bits [63:56]** — ``ArgKind`` (8 bits): ``kRegister`` (0), ``kImmediate`` (1), ``kConstIdx`` (2),
or ``kFuncIdx`` (3).
- **Bits [55:0]** — value (56 bits, sign-extended).
Two special register values exist:
- ``kVoidRegister``: indicates "no destination" (the return value is discarded).
- ``kVMRegister``: refers to the VM context pointer itself, passed as the first argument to
closures.
The instruction stream is stored as a flat ``vector<ExecWord>`` (``instr_data``) with an offset
table (``instr_offset``) for random access.
Executable
----------
A ``VMExecutable`` (``include/tvm/runtime/vm/executable.h``) bundles everything needed for
execution:
- **Function table** (``func_table``): a ``vector<VMFuncInfo>`` describing every function. Each
entry records the function's kind, name, instruction range (``start_instr`` to ``end_instr``),
number of arguments, register file size, and parameter names.
- **Constant pool** (``constants``): model weights, shape tuples, and other compile-time constants.
- **Bytecode** (``instr_data`` + ``instr_offset``): the instruction stream.
- **Imported modules**: the compiled TIR kernels and external libraries.
Function kinds
~~~~~~~~~~~~~~
The VM recognizes three function kinds (``VMFuncInfo::FuncKind``):
.. list-table::
:header-rows: 1
:widths: 20 80
* - Kind
- Description
* - ``kPackedFunc``
- An external C/C++ function looked up from imported modules or the global PackedFunc
registry. Examples: ``vm.builtin.alloc_shape_heap``, ``vm.builtin.match_shape``.
* - ``kVMFunc``
- A bytecode-interpreted Relax function. The VM interprets its instructions in ``RunLoop()``.
* - ``kVMTIRFunc``
- A Relax function compiled to a TIR function (``exec_mode="compiled"``). Found in
imports under the name ``__vmtir__<func_name>``. Called directly with register file
pointers, bypassing the interpreter loop.
Serialization
~~~~~~~~~~~~~
The executable supports binary serialization for deployment:
.. code-block:: python
# Save
ex.export_library("model.so")
# Load
loaded = tvm.runtime.load_module("model.so")
vm = relax.VirtualMachine(loaded, tvm.cuda())
The binary format includes a magic number (``0xD225DE2F4214151E``), a version string
(currently ``"0.14"``), followed by four sections: globals (the function table), memory scopes,
constant pool, and bytecode. ``AsText()`` and ``AsPython()`` provide human-readable representations
for debugging.
Runtime Execution
-----------------
VM initialization
~~~~~~~~~~~~~~~~~
At runtime, a ``VirtualMachine`` is created and initialized:
.. code-block:: python
from tvm.relax import VirtualMachine
vm = VirtualMachine(exec_module, tvm.cuda())
Under the hood:
1. **LoadExecutable**: the bytecode and metadata are loaded from the ``VMExecutable``.
2. **Init**: devices and memory allocators are set up. Each device gets an ``Allocator``
(either ``NAIVE_ALLOCATOR`` or ``POOLED_ALLOCATOR``, defaulting to pooled). A CPU device
is always added for shape computations.
3. **InitFuncPool**: the function pool is populated — ``kPackedFunc`` entries are resolved from
imports or the global registry; ``kVMFunc`` and ``kVMTIRFunc`` entries are wrapped in
``VMClosure`` objects.
4. **Constant pool**: model constants are loaded and optionally transferred to the target device.
The bytecode dispatch loop
~~~~~~~~~~~~~~~~~~~~~~~~~~
When a ``kVMFunc`` is invoked, the VM enters ``InvokeBytecode()``:
1. A new ``VMFrame`` is pushed onto the call stack. Each frame contains:
- A **register file** (``vector<ffi::Any>``) — type-erased slots that can hold tensors,
shapes, closures, or any TVM object. The size is determined at compile time
(``VMFuncInfo::register_file_size``).
- The **return program counter** — where to resume after the function returns.
- The **caller's return register** — which register in the parent frame receives the result.
2. Function arguments are written to registers 0..N-1.
3. The program counter (``pc_``) is set to the function's ``start_instr``.
4. ``RunLoop()`` executes instructions until a ``Ret`` is encountered:
- **Call**: resolve arguments (from registers, immediates, constant pool, or function pool),
invoke the target function via ``InvokeClosurePacked()``, store the result in ``dst``.
- **Ret**: read the return value from the specified register, write the result to the
caller's return register, and return from ``RunLoop()`` (the frame is popped by an RAII
guard when ``InvokeBytecode()`` exits).
- **Goto**: adjust ``pc_`` by the offset.
- **If**: check the condition register; if nonzero, fall through; otherwise jump by
``false_offset``.
The dispatch loop is implemented in ``src/runtime/vm/vm.cc`` (``VirtualMachineImpl::RunLoop``).
.. code-block:: text
Frame Stack Register File (per frame)
┌─────────────┐ ┌────┬────┬────┬─────┬────┐
│ Frame 2 │ ───────► │ R0 │ R1 │ R2 │ ... │ Rn │
├─────────────┤ └────┴────┴────┴─────┴────┘
│ Frame 1 │ ───────► [register file]
├─────────────┤
│ Frame 0 │ ───────► [register file]
└─────────────┘
VMClosure and function dispatch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Functions in the VM are stored in a ``func_pool_`` indexed by function table position.
``kVMFunc`` and ``kVMTIRFunc`` entries are wrapped as ``VMClosure`` objects, while ``kPackedFunc``
entries are stored as plain ``ffi::Function``. A ``VMClosure`` stores:
- ``func_name``: the function's string name.
- ``impl``: a ``ffi::Function`` that takes the VM context pointer as its first argument, followed
by the actual parameters.
When the VM encounters a ``Call`` instruction, it looks up the function in ``func_pool_`` by
index and dispatches via ``InvokeClosurePacked()``. If the target is a ``VMClosure``, the VM
pointer is prepended to the arguments and ``impl`` is invoked. If it is a plain
``ffi::Function``, it is called directly.
``VMClosure::BindLastArgs`` enables partial application — it creates a new function with
some arguments pre-bound at the end, useful for implementing captured closures in Relax.
Built-in operations
~~~~~~~~~~~~~~~~~~~
The VM relies on several built-in PackedFuncs (registered in ``src/runtime/vm/builtin.cc``)
for runtime support:
- ``vm.builtin.alloc_shape_heap``: allocate workspace for symbolic shape computations.
- ``vm.builtin.match_shape``: validate tensor shapes against expected patterns at runtime,
supporting assertions (``kAssertEqualToImm``, ``kAssertEqualToLoad``), storing symbolic
dimensions to the shape heap (``kStoreToHeap``), or no-ops (``kNoOp``).
- ``vm.builtin.make_shape``: construct shape tuples from immediates or heap-loaded values.
- ``vm.builtin.match_prim_value``: validate primitive values (e.g., integers) against expected
patterns.
- ``vm.builtin.copy``: copy a value into a register. Used in several codegen scenarios:
materializing non-register arguments (immediates, constants) into registers, ensuring each
variable binding gets its own register, and merging results from if/else branches.
Python Interface
----------------
Users interact with the VM through ``tvm.relax.VirtualMachine``:
.. code-block:: python
import tvm
from tvm import relax
import numpy as np
# Compile
ex = tvm.compile(MyModule, target="llvm")
# Create VM
vm = relax.VirtualMachine(ex, tvm.cpu())
# Direct invocation
inp = tvm.runtime.tensor(np.random.rand(3, 4).astype("float32"))
result = vm["main"](inp)
# Stateful interface (useful for RPC)
vm.set_input("main", inp)
vm.invoke_stateful("main")
output = vm.get_outputs("main")
Key methods:
- ``vm["func_name"](*args)`` — direct invocation, returns the result.
- ``vm.set_input()`` / ``vm.invoke_stateful()`` / ``vm.get_outputs()`` — stateful interface
that avoids sending output over the wire, useful for RPC-based remote execution.
- ``vm.save_function(func_name, saved_name, *args)`` — pre-bind arguments for repeated calls,
reducing dictionary lookup overhead during benchmarking.
- ``vm.time_evaluator(func_name, dev)`` — returns a timing function following the same convention
as ``tvm.runtime.Module.time_evaluator``.
- ``vm.set_instrument(func)`` — register an instrumentation callback that is invoked before/after
every ``Call`` instruction. The callback can return ``VMInstrumentReturnKind.SKIP_RUN`` to
skip the call.
Instrumentation
~~~~~~~~~~~~~~~
The VM supports observability via instrumentation:
**Instrumentation** via ``set_instrument()``:
.. code-block:: python
def my_instrument(func, func_symbol, before_run, ret_value, *args):
if before_run:
print(f"About to call: {func_symbol}")
return VMInstrumentReturnKind.NO_OP
vm.set_instrument(my_instrument)
vm["main"](inp)
The instrument function is called before and after every ``Call`` instruction, receiving the
function object, its symbol name, a flag indicating before/after, the return value (only valid
after), and all arguments.
Inspecting Bytecode
-------------------
The executable provides text and Python representations of the compiled bytecode:
.. code-block:: python
ex = tvm.compile(MyModule, target="llvm")
print(ex.as_text()) # Human-readable instruction listing
print(ex.as_python()) # Equivalent Python program
print(ex.stats()) # Summary statistics
These are invaluable for debugging compilation issues — they show exactly which functions
are called, in what order, and how registers are used.
Source Code Map
---------------
.. list-table::
:header-rows: 1
:widths: 45 55
* - Path
- Contents
* - ``include/tvm/runtime/vm/bytecode.h``
- Instruction, Opcode, and Arg definitions
* - ``include/tvm/runtime/vm/executable.h``
- VMExecutable, VMFuncInfo, serialization
* - ``include/tvm/runtime/vm/vm.h``
- VirtualMachine base class, VMClosure
* - ``src/runtime/vm/vm.cc``
- VirtualMachineImpl, RunLoop, InvokeBytecode
* - ``src/runtime/vm/executable.cc``
- Serialization/deserialization, text output
* - ``src/runtime/vm/builtin.cc``
- Built-in operations (shape matching, allocation)
* - ``src/relax/backend/vm/codegen_vm.cc``
- CodeGenVM: Relax IR → bytecode
* - ``src/relax/backend/vm/codegen_vm_tir.cc``
- VMTIRCodeGen: Relax IR → compiled TIR
* - ``python/tvm/runtime/vm.py``
- Python VirtualMachine wrapper
* - ``python/tvm/relax/vm_build.py``
- ``relax.build()`` and VMExecutable Python class