Summary

This RFC proposes a new TVMScript parser infrastructure, supporting extensive metaprogramming and syntactic sugars. The new infrastructure is IR-agnostic, treating TIR just as one of dialects. Additionally, the new infrastructure will provide better tooling around Python ecosystem (pylint, mypy, etc.).

Motivation

What is TVMScript. Check Blitz Course to TensorIR and TVMScript Unified Printer RFC for an introduction into TVMScript.

What is metaprogramming. In the context of TVMScript, metaprogramming means a programmable way to control IR generation. For example, in https://github.com/apache/tvm/pull/11097, a metaprogramming feature was added to the TVMScript parser, allows users to programmably control the shapes of the input buffers of a PrimFunc.

Limitation of current design

The current parser lacks capability on generic metaprogramming that allows user to have more control on IR construction. This makes it challenging to support operators like NMS (non-maximum suppression, which is crucial to object detection model). There is an implementation of NMS at python/tvm/topi/cuda/nms.py#L367-L386. The implementation of NMS-like operators requires rank-polymorphism and the ability to interleave host program with TVMScript, which is difficult to be implemented under the current design.

TVMScript also needs reasonable support on Python tooling. Currently it doesn’t play nicely with pylint and mypy. For example, test_meta_schedule_postproc_rewrite_tensorize.py has 100+ warnings from pylint within only 500 hundred lines of code. This creates confusion to the user and leaves an impression that TVMScript isn’t a mature product and not production-ready. Even though it’s something that can be incrementally improved under the current design, we believe it’s easier to get an ideal result if we have a design with the tooling support in mind.

The current design also lacks of unified approach for different IRs. At https://github.com/tlc-pack/relax/tree/relax/python/tvm/script/relax, a mature implementation of TVMScript parser is maintained for Relax. But it’s hard to extend if we want to support more IRs for TVM unity.

To conclude, with this RFC, we want to:

  1. Add more metaprogramming features to TVMScript, making it easier for TVM developers to write complicated operators.
  2. Improve tooling and documentation of TVMScript, reducing the friction for an average machine learning practitioner to use TVMScript.
  3. Modularize and infrastructuralize the TVMScript parser, lowering the cost to implement parser for new IR.

Guide-level explanation

Metaprogramming features to support

(F1) Template Metaprogramming

Users should be able to use variables from outer scope in the TVMScript function/class. The parsed result should be identical to function/class with the variable replaced by its value. For instance,

@T.prim_func
def matmul(
  A: T.Buffer[(128, 128)],
) -> None:
  ...

def gen_matmul(n, m) -> None:
  @T.prim_func
  def f(A: T.Buffer[(n, m)]):
    ...
  return f

f = gen_matmul(n=128, m=128) # `f` should be identical to `matmul`

This is already partially supported by https://github.com/apache/tvm/pull/11097 for using PrimExpr captured by outer function. With the new parser, we want to support this feature in more places and with more variable types.

(F2) Rank-polymorphism

Users should be able to write a single function to handle different ranks of input buffers (different numbers of dimensions). For example, user should be able to write a generic function to do broadcast add,

def broadcast_add(a, b, c):
  @T.prim_func
  def f(
    A: T.BufferFrom(a),
    B: T.BufferFrom(b),
    C: T.BufferFrom(c),
  ) -> None:
    for i, i_a, i_b in T.some_broadcast_method(A.shape, B.shape):
      with T.block():
        C[*i] = A[*i_a] + B[*i_b]

broadcast_add(
  a = Buffer((128, 1), "float32"),
  b = Buffer((1, 128), "float32"),
  c = Buffer((128, 128), "float32"),
)

(F3) Sugar: TE Compute in TIR

Users should be able to replace boilerplate code with a function call, which’s expanded to large chunk of code during parsing. For example, we may want to use TE’s compute-like syntax to replace nested loop,

@T.prim_func
def te_compute_sugar(
  A: T.Buffer[(128, 128)],
  B: T.Buffer[(128, 128)],
) -> None:
  ...
  C = T.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
  ...

## expands to ====>

@T.prim_func
def te_compute_expanded(
  A: T.Buffer[(128, 128)],
  B: T.Buffer[(128, 128)],
) -> None:
  ...
  for i in range(128):
    for j in range(128):
      with T.block("..."):
        C[i, j] = A[i, j] + B[i, j]
  ...

(F4) Interleave host program and TVMScript program to customize metaprogramming

As an escape hatch from writing code to be parsed by the TVMScript parser, users should be able to write imperative code to construct IR nodes directly and embed it inside regular TVMScript. Those code will be evaluated by the Python interpreter when parsing. This gives users the ultimate tool when TVMScript isn’t expressible enough for their use cases. For example, at python/tvm/topi/vision/nms.py#L380-L431, there are blocks of repetitive code on computing the coordinates of the four corners of bounding box. This can be simplified as:

# Before, without IRBuilder interleaving
@T.prim_func
def nms(...):
  ...
  for i in range(batch_size):
    ...
    a_l = min(
      output[batch_idx, box_a_idx, box_start_idx],
      output[batch_idx, box_a_idx, box_start_idx + 2],
    )
    a_t = min(
      output[batch_idx, box_a_idx, box_start_idx + 1],
      output[batch_idx, box_a_idx, box_start_idx + 3],
    )
    a_r = max(
      output[batch_idx, box_a_idx, box_start_idx],
      output[batch_idx, box_a_idx, box_start_idx + 2],
    )
    a_b = max(
      output[batch_idx, box_a_idx, box_start_idx + 1],
      output[batch_idx, box_a_idx, box_start_idx + 3],
    )
		...
    for k in range(j):
      check_iou = ...
			...
      if check_iou > 0:
        # b_l: left, b_t: top, b_r: right, b_b: bottom
        b_l = min(
          output[batch_idx, box_b_idx, box_start_idx],
          output[batch_idx, box_b_idx, box_start_idx + 2],
        )
        b_t = min(
          output[batch_idx, box_b_idx, box_start_idx + 1],
          output[batch_idx, box_b_idx, box_start_idx + 3],
        )
        b_r = max(
          output[batch_idx, box_b_idx, box_start_idx],
          output[batch_idx, box_b_idx, box_start_idx + 2],
        )
        b_b = max(
          output[batch_idx, box_b_idx, box_start_idx + 1],
          output[batch_idx, box_b_idx, box_start_idx + 3],
        )
        ...

# With IRBuilder interleaving:

from tvm.script import tir as T

def get_box_coordinates(output, batch_idx, box_idx, box_start_idx):
  """a method executed by python interpreter"""
  box_l = T.min(
    output[batch_idx, box_idx, box_start_idx],
    output[batch_idx, box_idx, box_start_idx + 2],
	) # type(box_l) is PrimExpr
  ... # Repeat for other coordinates
  return box_l, box_t, box_r, box_b

@T.prim_func(capture=[get_box_coordinates])
def nms(...):
  ...
  for i in range(batch_size):
    ...
    a_l, a_t, a_r, a_b = get_box_coordinates(output, batch_idx, box_a_idx, box_start_idx)
    ...
    for k in range(j):
      check_iou = ...
      ...
      if check_iou > 0:
        b_l, b_t, b_r, b_b = get_box_coordinates(output, batch_idx, box_b_idx, box_start_idx)
        ...

Reference-level explanation

IRBuilder as Core

As the foundation of IR construction, we will provide a set of APIs called IRBuilder to let user construct IR imperatively. IRBuilder will be used by the parser, as well as by users directly as described in the feature F4. IRBuilder allows user to write code in a style that’s similar to TVMScript, while it’s being executed as host program. For example,

from tvm.script.builder import Builder, def_, def_many
from tvm.script import tir as T

with Builder() as b:
  with T.prim_func():
    T.func_name("main")
    buffer_a = T.Buffer((128, 128, 128), "float32")
    buffer_b = T.Buffer((128, 128, 128), "float32")
    arg_a = T.arg("A", buffer_a)
    arg_b = T.arg("B", buffer_b)
    with T.grid(128, 128, 128) as (i, j, k):
      def_many(["i", "j", "k"], [i, j, k])
      with T.block(name="block"):
        vi = def_("vi", T.axis.spatial(128, i))
        vj = def_("vj", T.axis.spatial(128, j))
        vk = def_("vk", T.axis.reduce(128, k))

f = b.get() # f is a PrimFunc

produces equivalent result as

@T.prim_func
def main(
  A: T.Buffer(shape=(128, 128, 128), dtype="float32"),
  B: T.Buffer(shape=(128, 128, 128), dtype="float32"),
) -> None:
  for i, j, k in T.grid(128, 128, 128):
      with T.block("block"):
          vi = T.axis.S(128, i)
          vj = T.axis.S(128, j)
          vk = T.axis.R(128, k)

As shown in the example above, user doesn't need to pass the builder b to subsequent calls to IRBuilder API. The current builder state is maintained in a threadlocal store to improve the ergonomics of IRBuilder API by avoiding passing the builder state explicitly.

The implementation of IRBuilder will be in C++ so that it can be used in an environment without Python. Python binding will be created to expose IRBuilder to TVMScript parser.

With the separation between IRBuilder and parser, most implementation and documentation can be reused between TVMScript and IR definition. For example, most of operators are simply imported into the IRBuilder package, like

from tvm.tir import sin, cos

so the documentation and type signatures only need to be written once, and the APIs are guaranteed to be consistent.

Parse-time evaluation

TVMScript Parser can be considered as a thin layer built above the IRBuilder API. The parser transforms input AST into a sequence of calls to IRBuilder API, by evaluating small fragments of code as it visits AST. The IRBuilder is responsible for building the actual IR graph. All metaprogramming features we discussed above (F1 through F4) can be implemented through this parse-time evaluation in a consistent manner. Using the same gen_matmul example from F1,

def gen_matmul(n, m) -> None:
  @T.prim_func
  def f(A: T.Buffer[(n, m)]):
    ...
  return f

What parser does here is to:

  1. Collect the environment inside gen_matmul, getting a dictionary
    1. All primitive types (int, str, float and None) will be captured automatically, while advanced types (like function) needs to be explicitly declared in the decorator to be captured (for example, @T.prim_func(capture=[get_box_coordinates]))
  2. Call the corresponding function from IRBuilder, as the parser visits the AST of function f.
    1. When visiting the function argument, call eval on its type annotation, with the environment captured in the first step.
    2. T.Buffer[(n, m)] gets evaluated to a value with type tir.Buffer.
    3. Call the IRBuilder API T.arg("A", buffer) to add an arg to the function that’s being constructed

Another example,

for *i in T.grid(*A.shape):
  ...

The parser will:

  1. Evaluate the expression T.grid(*A.shape) by the step described above. T.grid returns a value that is nearly equivalent to List[Var].
  2. Call exec on a specially constructed statement *i = __tvm_rhs_var__, with locals that maps __tvm_rhs_var__ to the value evaluated in step 1.
  3. Collect the value of i from the locals dictionary
  4. Call the IRBuilder API def_many(["i"], [i])

As mentioned above, all metaprogramming features (F1 through F4) can be implemented through this parse-time evaluation. It's straightforward to see how F1 and F2 are implemented by parse-time evaluation, but it might be harder to grasp the idea behind F3 and F4.

For F3 (TE Compute in TIR),

C = T.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
# build a similar graph as
for i in range(128):
  for j in range(128):
    with T.block("..."):
      C[i, j] = A[i, j] + B[i, j]

T.compute is provided in IRBuilder API and will construct all IR nodes (For, BlockRealize, Block and BufferStore). T.compute calls the lambda function to get the rhs of BufferStore. Then T.compute returns a Buffer node that represents C. The parser handles the assignment by assigning "C" to the name_hint of the returned buffer (by calling IRBuilder API def_("C", C)), and put it into the internal variable table (which is used to resolve variable when evaluating subsequent statements and expressions).

For F4 (Interleave host program and TVMScript program),

a_l, a_t, a_r, a_b = get_box_coordinates(output, batch_idx, box_a_idx, box_start_idx)

The call to the get_box_coordinates function is evaluated when parser is visiting the assign statement. The parser calls IRBuilder def_many(["a_l", "a_t", "a_r", "a_b"], <returned_tuple>) and put them into the internal variable table.

Note that we will not place extra restriction on the signature of user-provided function (get_box_coordinates in this example). More precisely, the function argument types can be anything because parser is able to capture outter variables thus bringing variables with arbitrary type into the scope. The restriction on returned type depends on the language spec of the target IR. For example, in TIR user can write A[i] = ... to represent BufferStore, where the rhs is a PrimExpr, then the user can provide a custom function compute_magic_number(index: PrimExpr) -> PrimExpr to use on the rhs as A[i] = compute_magic_number(i).

By running eval and exec provided by the Python interpreter, we can implement language features which are difficult to implement manually, and also make sure TVMScript has the same semantics on expression compared to regular Python code.

Parser Registration

The logic of how to process a particular Python syntax node is registered through decorator. For example,

@dispatch.register(token="tir", type_name="For")
def visit_for(self: Parser, node: doc.For) -> None:
    ...

handles the transformation of Python For node. The token="tir" in the decorator means that the handler is for TIR. self: Parser has all the infrastructural API and maintains a stack of token to determine which function to dispatch to. This makes embedding different IR possible (for example, embedding TIR in Relax). The folder structure will look like

python/tvm/script/
└── parser
    ├── core
    │   └── ... # Parser infrastructure
    ├── tir # TIR dialect
    │   └── ...
    └── relax # Hypothetical Relax dialect (not part of our RFC)
        └── ...

Usage of eval and exec

The parser uses eval and exec in the following places:

  • It calls eval to evaluate fragment of expressions
  • It calls exec to evaluate different kinds of assignment statement, like first, *rest, last = T.grid(...)

The usage of eval and exec is necessary to our implementation. TVMScript allows users to construct IR graph in TVM declaratively as if they were writing Python code, lowering the barrier of using TVM to do low-level customization. However it is still restrictive and does not allow the usage of many Python features to build abstractions for user's code. All features proposed in this RFC can be seen as an effort to narrow this gap, thus they are designed to follow the Python semantics. On the implementation side, the most robust approach is to leverage the Python interpreter itself to facilitate those features, rather than write our own version of restricted Python interpreter. And eval and exec are the most suitable choices to achieve this. Other mechanism, like multiprocess + IPC and subinterpreter, either lacks an easy path to exchange Python objects, or requires dependency on the C API of CPython.

In our use cases, eval and exec do not create additional security risk for users. All inputs to eval come from user's code directly, without modification. Our usage of exec only executes a specific form of code,

<lhs> = __tvm_rhs_var__

where <lhs> is the tokens from the left hand side of assign statement, directly from user‘s code. The situation is very different from cases that make eval and exec infamous, where they are used to evaluate untrusted input from end users, typically in a web server. Furthermore, we require users to explicitly capture variables. Therefore, The evaluation will only involve objects from tvm, Python builtins and objects explicitly captured by users. This rules out the possibility that an external function is called by parser without user’s acknowledgement. If the malicious actor wants to exploit the system through the eval or exec in TVMScript parser, they must first get another RCE (remote code execution) vulnerability in the Python runtime to modify the code in runtime, which makes such exploit useless (one needs to first have an RCE to exploit another RCE with the same exposure).

Drawbacks

N/A

Rationale and alternatives

N/A

Prior art

Prior works in TVM

Other libraries

Taichi

https://www.taichi-lang.org

Taichi has a very similar metaprogramming model (link to doc) as we presented in this RFC. The biggest difference is that Taichi requires ti.static to be wrapped around everything that needs to be evaluated in compile time. It also has advanced features like loop unrolling, compile time branching and compile-time recursion.

In TVMScript parser, it does not need special marker to denote compile-time evaluation. Expressions are consistently evaluated in compile time (evaluated to IR node like PrimExpr, rather than the concrete value like a float matrix.), thanks to the separation of IRBuilder and parser. Features like loop unrolling can be implemented in the IRBuilder layer per target IR. This keeps the core parser as minimal as possible.

Triton

http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf

Triton uses the meta-parameters to generalize kernels (for example, tutorials/03-matrix-multiplication.html). Meta parameters are placed together with real parameters, but with type annotation tl.constexpr to differentiate. This method slightly deviates from regular Python semantics, as users will intuitively expect to pass them together with real parameters when calling the kernel.

In TVMScript, one of the design principle is to narrow the gap between TVMScript and regular Python code. TVMScript should not surprise users with syntax or language features that deviate from Python. All features proposed in this RFC are designed to strictly follow the semantics of Python and aimed to be intuitive to Python users.

Unresolved questions

N/A

Future possibilities

N/A