This RFC proposes a new TVMScript parser infrastructure that supports extensive metaprogramming and syntactic sugar. The new infrastructure is IR-agnostic, treating TIR as just one of its dialects. It will also integrate better with the Python tooling ecosystem (pylint, mypy, etc.).
What is TVMScript. See the Blitz Course to TensorIR and the TVMScript Unified Printer RFC for an introduction to TVMScript.
What is metaprogramming. In the context of TVMScript, metaprogramming means a programmable way to control IR generation. For example, https://github.com/apache/tvm/pull/11097 added a metaprogramming feature to the TVMScript parser that lets users programmatically control the shapes of the input buffers of a PrimFunc.
The current parser lacks generic metaprogramming capabilities that would give users more control over IR construction. This makes it challenging to support operators like NMS (non-maximum suppression, which is crucial to object detection models). There is an implementation of NMS at python/tvm/topi/cuda/nms.py#L367-L386. Implementing NMS-like operators requires rank polymorphism and the ability to interleave the host program with TVMScript, and both are difficult to implement under the current design.
TVMScript also needs reasonable support for Python tooling. Currently it doesn't play nicely with pylint and mypy. For example, test_meta_schedule_postproc_rewrite_tensorize.py produces 100+ pylint warnings in only 500 lines of code. This confuses users and leaves the impression that TVMScript is not a mature, production-ready product. This could be improved incrementally under the current design. However, we believe an ideal result is easier to reach with a design that has tooling support in mind from the start.
The current design also lacks a unified approach for different IRs. A mature TVMScript parser for Relax is maintained at https://github.com/tlc-pack/relax/tree/relax/python/tvm/script/relax. However, that parser is hard to extend to the additional IRs we want to support for TVM Unity.
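To conclude, with this RFC, we want to:

1. support extensive metaprogramming and syntactic sugar in TVMScript;
2. make TVMScript work well with Python tooling such as pylint and mypy;
3. build an IR-agnostic parser infrastructure, with TIR as just one of its dialects.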
F1 (Capture outer variables). Users should be able to use variables from an outer scope in a TVMScript function/class. The parsed result should be identical to the same function/class with each variable replaced by its value. For instance,
```python
@T.prim_func
def matmul(
    A: T.Buffer[(128, 128)],
) -> None:
    ...


def gen_matmul(n, m):
    @T.prim_func
    def f(A: T.Buffer[(n, m)]):
        ...

    return f


f = gen_matmul(n=128, m=128)  # `f` should be identical to `matmul`
```
https://github.com/apache/tvm/pull/11097 already partially supports this for `PrimExpr` values captured from an outer function. With the new parser, we want to support this feature in more places and for more variable types.
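The capture step can be modeled in plain Python. The sketch below shows one possible mechanism and is an assumption, not the proposed implementation. A hypothetical decorator factory `prim_func_sketch` reads the frame in which the decorator expression runs, using CPython's `inspect.currentframe()`. It keeps values of primitive types automatically and adds objects passed through `capture=[...]` explicitly. Reading the defining frame, rather than only the inner function's closure, also picks up names such as `n` and `m` that appear only in annotations.

```python
import inspect
from types import FrameType
from typing import Any, Callable, Dict, List, Optional

# Values of these primitive types are captured automatically.
AUTO_CAPTURE_TYPES = (int, str, float, type(None))


def capture_env(frame: FrameType, capture: Optional[List[Any]] = None) -> Dict[str, Any]:
    """Collect the variables visible where a decorated function is defined."""
    env: Dict[str, Any] = {}
    visible = {**frame.f_globals, **frame.f_locals}  # locals shadow globals
    for name, value in visible.items():
        if isinstance(value, AUTO_CAPTURE_TYPES):
            env[name] = value
    # Anything else (functions, modules, ...) must be listed explicitly.
    for obj in capture or []:
        env[obj.__name__] = obj
    return env


def prim_func_sketch(capture: Optional[List[Any]] = None) -> Callable[[Callable], Callable]:
    """Hypothetical stand-in for @T.prim_func(capture=[...])."""
    caller = inspect.currentframe().f_back  # the frame that contains the `def`
    env = capture_env(caller, capture)

    def decorator(func: Callable) -> Callable:
        func.captured_env = env  # a real parser would now parse `func` against `env`
        return func

    return decorator


def helper() -> int:
    return 1


def gen_matmul(n: int, m: int) -> Callable:
    @prim_func_sketch(capture=[helper])
    def f(A):
        ...

    return f
```

With this sketch, `gen_matmul(n=128, m=64).captured_env` contains `n`, `m` and the explicitly captured `helper`. Other functions in scope, such as `gen_matmul` itself, are left out.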
F2 (Rank polymorphism). Users should be able to write a single function that handles input buffers of different ranks (different numbers of dimensions). For example, a user should be able to write a generic broadcast add:
```python
def broadcast_add(a, b, c):
    @T.prim_func
    def f(
        A: T.BufferFrom(a),
        B: T.BufferFrom(b),
        C: T.BufferFrom(c),
    ) -> None:
        for i, i_a, i_b in T.some_broadcast_method(A.shape, B.shape):
            with T.block():
                C[*i] = A[*i_a] + B[*i_b]

    return f


broadcast_add(
    a=Buffer((128, 1), "float32"),
    b=Buffer((1, 128), "float32"),
    c=Buffer((128, 128), "float32"),
)
```
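`T.some_broadcast_method` is a placeholder in the example above. The sketch below gives a rough idea of what such a helper has to compute for inputs of any rank. It uses plain integers instead of TIR variables and assumes NumPy-style broadcasting: trailing dimensions are aligned, and a size-1 dimension stretches. `broadcast_shape` and `broadcast_indices` are hypothetical names.

```python
import itertools
from typing import Iterator, Sequence, Tuple

Index = Tuple[int, ...]


def broadcast_shape(shape_a: Sequence[int], shape_b: Sequence[int]) -> Tuple[int, ...]:
    """Output shape under NumPy-style broadcasting (an assumption of this sketch)."""
    rank = max(len(shape_a), len(shape_b))
    a = (1,) * (rank - len(shape_a)) + tuple(shape_a)  # left-pad with size-1 dims
    b = (1,) * (rank - len(shape_b)) + tuple(shape_b)
    out = []
    for da, db in zip(a, b):
        if da != db and 1 not in (da, db):
            raise ValueError(f"incompatible dimensions {da} and {db}")
        out.append(db if da == 1 else da)
    return tuple(out)


def broadcast_indices(
    shape_a: Sequence[int], shape_b: Sequence[int]
) -> Iterator[Tuple[Index, Index, Index]]:
    """Yield (i, i_a, i_b): an output index and the matching input indices."""
    out_shape = broadcast_shape(shape_a, shape_b)
    for i in itertools.product(*(range(d) for d in out_shape)):
        # Drop the leading dims an input does not have, and pin its
        # size-1 (broadcast) dims to index 0.
        tail_a = i[len(i) - len(shape_a):]
        tail_b = i[len(i) - len(shape_b):]
        i_a = tuple(0 if d == 1 else x for x, d in zip(tail_a, shape_a))
        i_b = tuple(0 if d == 1 else x for x, d in zip(tail_b, shape_b))
        yield i, i_a, i_b
```

The same loop works unchanged for rank-1, rank-2 or rank-5 inputs, which is the property F2 asks of the parser-level helper.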
F3 (TE Compute in TIR). Users should be able to replace boilerplate code with a function call that is expanded into a large chunk of code during parsing. For example, we may want to use TE's compute-like syntax to replace nested loops:
```python
@T.prim_func
def te_compute_sugar(
    A: T.Buffer[(128, 128)],
    B: T.Buffer[(128, 128)],
) -> None:
    ...
    C = T.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
    ...


## expands to ====>


@T.prim_func
def te_compute_expanded(
    A: T.Buffer[(128, 128)],
    B: T.Buffer[(128, 128)],
) -> None:
    ...
    for i in range(128):
        for j in range(128):
            with T.block("..."):
                C[i, j] = A[i, j] + B[i, j]
    ...
```
F4 (Interleave host program and TVMScript program). Users should have an escape hatch from writing code that the TVMScript parser parses. They should be able to write imperative code that constructs IR nodes directly and embed it inside regular TVMScript. The Python interpreter evaluates this code during parsing. This gives users the ultimate tool when TVMScript isn't expressive enough for their use cases. For example, python/tvm/topi/vision/nms.py#L380-L431 contains blocks of repetitive code that compute the coordinates of the four corners of a bounding box. This can be simplified as:
```python
# Before, without IRBuilder interleaving
@T.prim_func
def nms(...):
    ...
    for i in range(batch_size):
        ...
        a_l = min(
            output[batch_idx, box_a_idx, box_start_idx],
            output[batch_idx, box_a_idx, box_start_idx + 2],
        )
        a_t = min(
            output[batch_idx, box_a_idx, box_start_idx + 1],
            output[batch_idx, box_a_idx, box_start_idx + 3],
        )
        a_r = max(
            output[batch_idx, box_a_idx, box_start_idx],
            output[batch_idx, box_a_idx, box_start_idx + 2],
        )
        a_b = max(
            output[batch_idx, box_a_idx, box_start_idx + 1],
            output[batch_idx, box_a_idx, box_start_idx + 3],
        )
        ...
        for k in range(j):
            check_iou = ...
            ...
            if check_iou > 0:
                # b_l: left, b_t: top, b_r: right, b_b: bottom
                b_l = min(
                    output[batch_idx, box_b_idx, box_start_idx],
                    output[batch_idx, box_b_idx, box_start_idx + 2],
                )
                b_t = min(
                    output[batch_idx, box_b_idx, box_start_idx + 1],
                    output[batch_idx, box_b_idx, box_start_idx + 3],
                )
                b_r = max(
                    output[batch_idx, box_b_idx, box_start_idx],
                    output[batch_idx, box_b_idx, box_start_idx + 2],
                )
                b_b = max(
                    output[batch_idx, box_b_idx, box_start_idx + 1],
                    output[batch_idx, box_b_idx, box_start_idx + 3],
                )
                ...


# With IRBuilder interleaving:
from tvm.script import tir as T


def get_box_coordinates(output, batch_idx, box_idx, box_start_idx):
    """a method executed by python interpreter"""
    box_l = T.min(
        output[batch_idx, box_idx, box_start_idx],
        output[batch_idx, box_idx, box_start_idx + 2],
    )  # type(box_l) is PrimExpr
    ...  # Repeat for other coordinates
    return box_l, box_t, box_r, box_b


@T.prim_func(capture=[get_box_coordinates])
def nms(...):
    ...
    for i in range(batch_size):
        ...
        a_l, a_t, a_r, a_b = get_box_coordinates(output, batch_idx, box_a_idx, box_start_idx)
        ...
        for k in range(j):
            check_iou = ...
            ...
            if check_iou > 0:
                b_l, b_t, b_r, b_b = get_box_coordinates(output, batch_idx, box_b_idx, box_start_idx)
                ...
```
As the foundation of IR construction, we will provide a set of APIs called IRBuilder that lets users construct IR imperatively. The parser will use IRBuilder, and users can also call it directly, as described in feature F4. With IRBuilder, users write code in a style similar to TVMScript, but the code runs as a host program. For example,
```python
from tvm.script.builder import Builder, def_, def_many
from tvm.script import tir as T

with Builder() as b:
    with T.prim_func():
        T.func_name("main")
        buffer_a = T.Buffer((128, 128, 128), "float32")
        buffer_b = T.Buffer((128, 128, 128), "float32")
        arg_a = T.arg("A", buffer_a)
        arg_b = T.arg("B", buffer_b)
        with T.grid(128, 128, 128) as (i, j, k):
            def_many(["i", "j", "k"], [i, j, k])
            with T.block(name="block"):
                vi = def_("vi", T.axis.spatial(128, i))
                vj = def_("vj", T.axis.spatial(128, j))
                vk = def_("vk", T.axis.reduce(128, k))
f = b.get()  # f is a PrimFunc
```
produces a result equivalent to
```python
@T.prim_func
def main(
    A: T.Buffer(shape=(128, 128, 128), dtype="float32"),
    B: T.Buffer(shape=(128, 128, 128), dtype="float32"),
) -> None:
    for i, j, k in T.grid(128, 128, 128):
        with T.block("block"):
            vi = T.axis.S(128, i)
            vj = T.axis.S(128, j)
            vk = T.axis.R(128, k)
```
As the example above shows, users don't need to pass the builder `b` to subsequent IRBuilder API calls. The current builder state is kept in a thread-local store. This improves the ergonomics of the IRBuilder API because the builder state never has to be passed explicitly.
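A toy model illustrates the thread-local builder state. In this sketch all class and function names are hypothetical, and nodes are plain dicts rather than TVM IR. Entering a `Builder` pushes it onto a `threading.local` stack. Each frame then looks up the active builder itself instead of receiving it as an argument, which mirrors how `b` is never passed around in the IRBuilder example.

```python
import threading
from typing import Any, Dict, List

_local = threading.local()  # each thread sees its own stack of active builders


class Builder:
    """Toy builder that records a tree of nested scopes (not TVM's IRBuilder)."""

    def __init__(self) -> None:
        self.root: List[Dict[str, Any]] = []
        self.scopes: List[List[Dict[str, Any]]] = [self.root]  # innermost last

    def __enter__(self) -> "Builder":
        if not hasattr(_local, "stack"):
            _local.stack = []
        _local.stack.append(self)
        return self

    def __exit__(self, *exc_info: Any) -> None:
        _local.stack.pop()

    @staticmethod
    def current() -> "Builder":
        stack = getattr(_local, "stack", [])
        if not stack:
            raise RuntimeError("no active Builder in this thread")
        return stack[-1]

    def get(self) -> List[Dict[str, Any]]:
        return self.root


class Frame:
    """A scope (function, loop, block) that collects the nodes created inside it."""

    def __init__(self, kind: str, name: str) -> None:
        self.node: Dict[str, Any] = {"kind": kind, "name": name, "body": []}

    def __enter__(self) -> Dict[str, Any]:
        builder = Builder.current()  # looked up, never passed in by the caller
        builder.scopes[-1].append(self.node)
        builder.scopes.append(self.node["body"])
        return self.node

    def __exit__(self, *exc_info: Any) -> None:
        Builder.current().scopes.pop()


def func_frame(name: str) -> Frame:  # hypothetical analogue of T.prim_func()
    return Frame("func", name)


def for_frame(var: str) -> Frame:  # hypothetical analogue of a loop frame
    return Frame("for", var)


with Builder() as b:
    with func_frame("main"):
        with for_frame("i"):
            with for_frame("j"):
                pass
tree = b.get()
```

The stack lives in `threading.local`, so a builder opened in one thread is invisible to every other thread. In this design, concurrent parses in different threads do not interfere with each other.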
IRBuilder will be implemented in C++ so that it can be used in environments without Python. Python bindings will expose IRBuilder to the TVMScript parser.
Because IRBuilder and the parser are separate, most of the implementation and documentation can be shared between TVMScript and the IR definitions. For example, most operators are simply imported into the IRBuilder package, like

```python
from tvm.tir import sin, cos
```

so the documentation and type signatures only need to be written once, and the APIs are guaranteed to be consistent.
The TVMScript parser can be considered a thin layer on top of the IRBuilder API. As it visits the input AST, the parser evaluates small fragments of code and turns the AST into a sequence of IRBuilder API calls. The IRBuilder is responsible for building the actual IR graph. All the metaprogramming features discussed above (F1 through F4) can be implemented consistently through this parse-time evaluation. Take the same `gen_matmul` example from F1:
```python
def gen_matmul(n, m):
    @T.prim_func
    def f(A: T.Buffer[(n, m)]):
        ...

    return f
```
The parser does the following here:

1. Collect the environment inside `gen_matmul`, producing a dictionary.
   - Values of primitive types (`int`, `str`, `float` and `None`) are captured automatically. Advanced types (like functions) must be declared explicitly in the decorator to be captured, for example `@T.prim_func(capture=[get_box_coordinates])`.
2. Visit the definition of `f` and start building the corresponding PrimFunc.
3. For each argument, call `eval` on its type annotation, using the environment captured in the first step.
   - `T.Buffer[(n, m)]` evaluates to a value of type `tir.Buffer`.
4. Call the IRBuilder API `T.arg("A", buffer)` to add an argument to the function being constructed.

Another example:
```python
for *i in T.grid(*A.shape):
    ...
```
The parser will:

1. Evaluate `T.grid(*A.shape)` using the step described above. `T.grid` returns a value that is nearly equivalent to `List[Var]`.
2. Call `exec` on a specially constructed statement `*i = __tvm_rhs_var__`, with `locals` mapping `__tvm_rhs_var__` to the value evaluated in step 1.
3. Collect the value of `i` from the `locals` dictionary.
4. Call the IRBuilder API `def_many(["i"], [i])`.
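Steps 2 and 3 above can be tried out with plain Python values. The sketch below uses a hypothetical helper `exec_assign`. Integers stand in for the `Var`s that `T.grid` returns, and the final `def_many` call appears only as a comment. The example binds the `first, *rest, last` form that this RFC also mentions, so Python itself handles the starred and nested targets.

```python
from typing import Any, Dict


def exec_assign(lhs_src: str, rhs_value: Any, env: Dict[str, Any]) -> Dict[str, Any]:
    """Bind `lhs_src` to `rhs_value` using Python's own assignment rules.

    `lhs_src` is the left-hand side text copied from the user's code, and
    `env` holds the names the parser has already resolved.
    Returns the names bound by the assignment and their values.
    """
    local_vars: Dict[str, Any] = {"__tvm_rhs_var__": rhs_value}
    # Only code of the form `<lhs> = __tvm_rhs_var__` is ever executed.
    exec(f"{lhs_src} = __tvm_rhs_var__", dict(env), local_vars)
    del local_vars["__tvm_rhs_var__"]
    return local_vars  # the parser would now call def_many(names, values)


# Plain integers stand in for the Vars returned by T.grid(...).
bound = exec_assign("first, *rest, last", [10, 20, 30, 40], {})
names, values = list(bound), list(bound.values())
```

Delegating destructuring to `exec` means the parser never re-implements Python's rules for starred or nested targets.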
As mentioned above, all metaprogramming features (F1 through F4) can be implemented through this parse-time evaluation. It's straightforward to see how F1 and F2 are implemented by parse-time evaluation, but it might be harder to grasp the idea behind F3 and F4.
For F3 (TE Compute in TIR),
```python
C = T.compute((128, 128), lambda i, j: A[i, j] + B[i, j])

# builds a graph similar to
for i in range(128):
    for j in range(128):
        with T.block("..."):
            C[i, j] = A[i, j] + B[i, j]
```
The IRBuilder API provides `T.compute`, which constructs all the IR nodes (`For`, `BlockRealize`, `Block` and `BufferStore`). `T.compute` calls the lambda function to get the rhs of the `BufferStore` and then returns a `Buffer` node that represents `C`. The parser handles the assignment by setting the `name_hint` of the returned buffer to `"C"` through the IRBuilder API call `def_("C", C)`. It also puts the buffer into the internal variable table, which resolves variables when subsequent statements and expressions are evaluated.
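A toy version of the `T.compute` sugar shows the mechanism: an ordinary function builds the IR nodes by calling the user's lambda with fresh loop variables. This is a simplified sketch with several assumptions. The dataclasses below are hypothetical stand-ins for TIR's `For`, `BufferStore` and `Buffer`, not `tvm.tir` classes. The `Block`/`BlockRealize` nodes are omitted. Statements are appended to an explicit list instead of the builder's current frame.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Sequence, Tuple


@dataclass
class Var:
    name: str


@dataclass
class Add:
    a: Any
    b: Any


@dataclass
class BufferLoad:
    buffer: "Buffer"
    indices: Tuple[Any, ...]

    def __add__(self, other: Any) -> Add:
        return Add(self, other)  # build an expression node, do not compute


@dataclass(eq=False)
class Buffer:
    shape: Tuple[int, ...]
    name_hint: str = ""

    def __getitem__(self, indices: Any) -> BufferLoad:
        return BufferLoad(self, indices if isinstance(indices, tuple) else (indices,))


@dataclass
class BufferStore:
    buffer: Buffer
    value: Any
    indices: Tuple[Var, ...]


@dataclass
class For:
    loop_var: Var
    extent: int
    body: Any


def compute(shape: Sequence[int], fcompute: Callable[..., Any], emitted: List[Any]) -> Buffer:
    """Hypothetical T.compute: emit a loop nest and return the output buffer."""
    out = Buffer(tuple(shape))
    loop_vars = [Var(f"i{d}") for d in range(len(shape))]
    # Calling the lambda with loop variables yields the rhs expression tree.
    stmt: Any = BufferStore(out, fcompute(*loop_vars), tuple(loop_vars))
    for var, extent in reversed(list(zip(loop_vars, shape))):
        stmt = For(var, extent, stmt)  # wrap from the innermost loop outwards
    emitted.append(stmt)
    return out


A, B = Buffer((128, 128), "A"), Buffer((128, 128), "B")
stmts: List[Any] = []
C = compute((128, 128), lambda i, j: A[i, j] + B[i, j], stmts)
# What the parser does for `C = ...`: name the buffer and record it.
C.name_hint = "C"  # stands in for def_("C", C)
var_table = {"C": C}
```

The parser never sees the loops. It only evaluates the call and binds the returned buffer, so adding sugar of this kind requires no parser changes.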
For F4 (Interleave host program and TVMScript program),
```python
a_l, a_t, a_r, a_b = get_box_coordinates(output, batch_idx, box_a_idx, box_start_idx)
```
The call to `get_box_coordinates` is evaluated when the parser visits the assign statement. The parser then calls the IRBuilder API `def_many(["a_l", "a_t", "a_r", "a_b"], <returned_tuple>)` and puts the results into the internal variable table.
Note that we will not place extra restrictions on the signature of the user-provided function (`get_box_coordinates` in this example). More precisely, the function arguments can be of any type, because the parser can capture outer variables and thereby bring variables of arbitrary type into scope. The language spec of the target IR determines the restrictions on the return type. For example, in TIR a user can write `A[i] = ...` to represent a `BufferStore` whose rhs is a `PrimExpr`. The user can then provide a custom function `compute_magic_number(index: PrimExpr) -> PrimExpr` and use it on the rhs as `A[i] = compute_magic_number(i)`.
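The parse-time handling of such an assignment can be sketched with the standard `ast` module. In the sketch below, `visit_assign` is hypothetical. It evaluates the right-hand side with `eval`, so the captured host function simply runs. It then reads the target names from the AST and records the results in the variable table, where the real parser would call `def_many`. The sketch supports only plain names or a flat tuple of names, and integers stand in for `PrimExpr`s.

```python
import ast
from typing import Any, Dict, List


def visit_assign(stmt_src: str, var_table: Dict[str, Any], captured: Dict[str, Any]) -> None:
    """Handle `name, ... = <expr>` by evaluating <expr> with Python itself."""
    node = ast.parse(stmt_src).body[0]
    if not (isinstance(node, ast.Assign) and len(node.targets) == 1):
        raise NotImplementedError("only single-target assignments in this sketch")
    # 1. Evaluate the right-hand side; host functions run right here.
    code = compile(ast.Expression(body=node.value), "<tvmscript>", "eval")
    rhs = eval(code, dict(captured), dict(var_table))
    # 2. Read target names from the AST (a plain name or a flat tuple of names).
    target = node.targets[0]
    if isinstance(target, ast.Tuple):
        names: List[str] = [elt.id for elt in target.elts]
        values = list(rhs)
    else:
        names, values = [target.id], [rhs]
    if len(names) != len(values):
        raise ValueError(f"expected {len(names)} values, got {len(values)}")
    # 3. Stand-in for the IRBuilder call def_many(names, values).
    var_table.update(zip(names, values))


def get_box_coordinates(output, box_idx):
    """Host helper run by the interpreter; numbers stand in for PrimExprs."""
    x0, y0, x1, y1 = output[box_idx]
    return min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)


var_table = {"output": [(4, 1, 2, 3)], "box_a_idx": 0}
visit_assign(
    "a_l, a_t, a_r, a_b = get_box_coordinates(output, box_a_idx)",
    var_table,
    captured={"get_box_coordinates": get_box_coordinates},
)
```

The parser places no constraint on the helper's signature. It only needs the returned tuple to match the number of targets.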
By relying on the interpreter's own `eval` and `exec`, we can implement language features that would be difficult to implement by hand. This also ensures that TVMScript expressions have the same semantics as in regular Python code.
A decorator registers the logic for processing each kind of Python syntax node. For example,
```python
@dispatch.register(token="tir", type_name="For")
def visit_for(self: Parser, node: doc.For) -> None:
    ...
```
handles the transformation of the Python `For` node. The `token="tir"` in the decorator means that the handler is for TIR. `self: Parser` provides all the infrastructural APIs and maintains a stack of tokens that determines which function to dispatch to. This makes it possible to embed different IRs (for example, embedding TIR in Relax). The folder structure will look like this:
```
python/tvm/script/
└── parser
    ├── core
    │   └── ...  # Parser infrastructure
    ├── tir      # TIR dialect
    │   └── ...
    └── relax    # Hypothetical Relax dialect (not part of our RFC)
        └── ...
```
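The token-based dispatch can be modeled with a registry keyed by `(token, node type name)`. This sketch uses Python's standard `ast` nodes in place of the parser's `doc` nodes. Its `register` and `Parser` are simplified, hypothetical counterparts of `dispatch.register` and the real parser class. It requires Python 3.9+ for `ast.unparse`.

```python
import ast
from typing import Callable, Dict, List, Tuple

_REGISTRY: Dict[Tuple[str, str], Callable] = {}


def register(token: str, type_name: str) -> Callable[[Callable], Callable]:
    """Register a handler for AST nodes named `type_name` in dialect `token`."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[(token, type_name)] = fn
        return fn
    return decorator


class Parser:
    def __init__(self, token: str) -> None:
        self.tokens: List[str] = [token]  # dialect stack; innermost is last
        self.log: List[str] = []

    def visit(self, node: ast.AST) -> None:
        key = (self.tokens[-1], type(node).__name__)
        handler = _REGISTRY.get(key)
        if handler is None:
            raise NotImplementedError(f"no handler for {key}")
        handler(self, node)

    def visit_body(self, stmts: List[ast.stmt]) -> None:
        for stmt in stmts:
            self.visit(stmt)


@register(token="tir", type_name="For")
def visit_for(self: Parser, node: ast.For) -> None:
    self.log.append(f"tir.For({ast.unparse(node.target)})")
    self.visit_body(node.body)


@register(token="tir", type_name="Pass")
def visit_pass(self: Parser, node: ast.Pass) -> None:
    self.log.append("tir.Pass")


tree = ast.parse("for i in range(4):\n    for j in range(4):\n        pass\n")
p = Parser("tir")
p.visit_body(tree.body)
```

When the parser enters an embedded region, it can push another token (for example `"relax"`) onto `tokens`. This switches the active set of handlers without touching the TIR ones.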
Use of `eval` and `exec`. The parser uses `eval` and `exec` in the following places:

- `eval` to evaluate fragments of expressions
- `exec` to evaluate different kinds of assignment statements, like `first, *rest, last = T.grid(...)`
Our implementation requires `eval` and `exec`. TVMScript lets users construct IR graphs in TVM declaratively, as if they were writing Python code, which lowers the barrier to low-level customization with TVM. However, it is still restrictive: users cannot use many Python features to build abstractions in their code. All features proposed in this RFC are an effort to narrow this gap, so they are designed to follow Python semantics. On the implementation side, the most robust approach is to let the Python interpreter itself provide those features, rather than writing our own restricted Python interpreter. `eval` and `exec` are the most suitable tools for this. Other mechanisms, like multiprocessing + IPC and subinterpreters, either lack an easy way to exchange Python objects or require a dependency on the CPython C API.
In our use cases, `eval` and `exec` do not create additional security risks for users. All inputs to `eval` come directly from the user's code, without modification. Our usage of `exec` only executes code of one specific form, `<lhs> = __tvm_rhs_var__`, where `<lhs>` is the left-hand side tokens of an assign statement, taken directly from the user's code. This is very different from the cases that made `eval` and `exec` infamous, where they evaluate untrusted input from end users, typically in a web server. Furthermore, apart from primitive values, users must explicitly capture the objects they want to use. The evaluation therefore only involves objects from `tvm`, Python builtins, automatically captured primitive values and objects the user explicitly captured. This rules out the parser calling an external function without the user's acknowledgement. A malicious actor who wants to exploit the `eval` or `exec` in the TVMScript parser must first obtain another RCE (remote code execution) vulnerability in the Python runtime to modify the code at runtime. This makes such an exploit pointless: one would need an RCE to exploit another RCE with the same exposure.
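The claim that evaluation involves only a known set of objects comes down to name resolution in `eval`: a name resolves only if the parser placed it in the environment. The sketch below, with a hypothetical `eval_fragment`, illustrates this. It is not a sandbox, because builtins such as `__import__` remain reachable. As argued above, safety comes from the input being the user's own code.

```python
import builtins
from typing import Any, Dict


def eval_fragment(expr: str, tvm_names: Dict[str, Any], captured: Dict[str, Any]) -> Any:
    """Evaluate an expression from the user's source against an explicit environment.

    Only builtins, the TVM namespace and captured objects are resolvable.
    This limits name lookup; it is NOT a sandbox.
    """
    env: Dict[str, Any] = {"__builtins__": builtins, **tvm_names, **captured}
    return eval(expr, env)
```

For example, `eval_fragment("add(x, 1)", {"add": lambda a, b: a + b}, {"x": 41})` evaluates to 42. Referencing a helper that was never captured raises `NameError` instead of silently reaching into the caller's module.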
Taichi has a metaprogramming model very similar to the one presented in this RFC. The biggest difference is that Taichi requires `ti.static` to be wrapped around everything that must be evaluated at compile time. Taichi also has advanced features like loop unrolling, compile-time branching and compile-time recursion.
The TVMScript parser needs no special marker to denote compile-time evaluation. Thanks to the separation of IRBuilder and parser, expressions are consistently evaluated at compile time, to IR nodes like `PrimExpr` rather than concrete values like a float matrix. Features like loop unrolling can be implemented in the IRBuilder layer for each target IR. This keeps the core parser as minimal as possible.
Triton (http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf) uses meta-parameters to generalize kernels (for example, tutorials/03-matrix-multiplication.html). Meta-parameters sit alongside the real parameters and are distinguished by the type annotation `tl.constexpr`. This slightly deviates from regular Python semantics, because users will intuitively expect to pass them together with the real parameters when calling the kernel.
In TVMScript, one of the design principles is to narrow the gap between TVMScript and regular Python code. TVMScript should not surprise users with syntax or language features that deviate from Python. All features proposed in this RFC are designed to strictly follow Python semantics and to be intuitive to Python users.